import pickle
import utils
import numpy as np
import timeit
from pynndescent import NNDescent
import FalconnLite
import csv
import os
import sys

# ==============================================================
# LOG SETUP
# ==============================================================

# Name of the log file that receives a copy of everything written to
# stdout/stderr by this script.
log_name = "approx_join_k_sweep_repeat.log"
# Explicit UTF-8: the log mirrors arbitrary stderr text (tracebacks, library
# warnings), which could otherwise raise UnicodeEncodeError when the platform
# default encoding is not UTF-8.
log_file = open(log_name, "w", encoding="utf-8")

class Tee:
    """Write-only text stream that duplicates output to several streams.

    Intended as a drop-in replacement for ``sys.stdout`` / ``sys.stderr`` so
    that console output is also captured in a log file.

    Args:
        *files: Target streams. Each must provide ``write`` and ``flush``.
            The first one is treated as the primary stream for attribute
            delegation (see ``__getattr__``).
    """

    def __init__(self, *files):
        self.files = files

    def write(self, data):
        """Write *data* to every open target and flush it immediately.

        Targets that have already been closed are skipped instead of raising
        ``ValueError``, so closing the log file does not break later prints.

        Returns:
            The number of characters in *data*, matching the text-stream
            ``write`` contract.
        """
        for f in self.files:
            if getattr(f, "closed", False):
                continue
            f.write(data)
            # Flush eagerly so the log stays current if the run is killed.
            f.flush()
        return len(data)

    def flush(self):
        """Flush every target that is still open."""
        for f in self.files:
            if not getattr(f, "closed", False):
                f.flush()

    def __getattr__(self, name):
        """Delegate attributes Tee lacks (``isatty``, ``encoding``, ...) to the first target."""
        # Only called when normal lookup fails. Guard ``files`` and dunder
        # names to avoid infinite recursion and surprising protocol lookups.
        if name == "files" or name.startswith("__") or not self.files:
            raise AttributeError(name)
        return getattr(self.files[0], name)

# Wrap both standard streams so that everything printed from here on is
# shown on the console and appended to the log file as well.
sys.stdout, sys.stderr = Tee(sys.stdout, log_file), Tee(sys.stderr, log_file)

print(f"[LOG] Writing output to {log_name}")

# ==============================================================
# Dataset configurations
# ==============================================================
# Root directories of the raw vector files and the ground-truth neighbour
# files. They were referenced below without being defined; they are now read
# from the environment and default to the current directory.
DATA_DIR = os.environ.get("DATA_DIR", ".")
GT_DIR = os.environ.get("GT_DIR", ".")

# One entry per dataset. The main loop reads all four keys of each entry.
dataset_paths = {
    "mnist": {
        "bin_file": f"{DATA_DIR}/mnist_X_60K_784.bin",
        "gt_file": f"{GT_DIR}/Mnist60K_Cosine_k_100_indices.npy",
        # Number of points and dimensionality, required by the main loop.
        # NOTE(review): values taken from the file name (60K x 784) - confirm.
        "n": 60000,
        "d": 784,
    }
}

# ==============================================================
# Fixed experiment parameters
# ==============================================================

# Label used in the CSV file name and in the "Method" column.
method = "approx_join"

# Distance name handed to FalconnLite and recorded in the CSV.
dist = "Cosine"

# Seed used as NNDescent's random_state and embedded in the CSV file name.
trial_seed = 100

# FalconnLite settings: number of projections (n_proj) and probe counts.
D, iProbe, qProbe = 512, 3, 1
n_threads = 32
verbose = True

# Swept: neighbourhood size k.
k_list = [20, 50]

# Fixed: a single top_m value.
topm_list = [25]

# Swept: n_repeats argument of approx_join.
n_repeats_list = [1, 4, 8, 12, 16]

# NNDescent iteration counts, tried from 10 down to 0.
iters = list(range(10, -1, -1))

# ==============================================================
# Main Loop
# ==============================================================

# For every dataset and every k: build an initial k-NN graph with
# FalconnLite.approx_join for each (n_repeats, top_m) pair, refine it with
# NNDescent for each iteration count in `iters`, and append one CSV row per
# refinement run with timings and accuracy against the ground truth.
print(f"\n============ Running single seed {trial_seed} ============\n")

for dataset_name, data_info in dataset_paths.items():

    print(f"\n======== Dataset: {dataset_name} | Seed={trial_seed} ========\n")

    bin_file = data_info["bin_file"]
    n = data_info["n"]
    d = data_info["d"]
    gt_file = data_info["gt_file"]

    # ---------------- Load dataset ----------------
    X = utils.mmap_bin(bin_file, n, d)

    # ---------------- Normalize (cosine) ----------------
    # Scale each row to unit L2 norm. Zero-norm rows are divided by 1 so
    # they stay all-zero instead of producing NaNs.
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    norms[norms == 0] = 1
    X = X / norms

    output_dir = f"{dataset_name}_result/Sensitivity"
    os.makedirs(output_dir, exist_ok=True)

    # The ground-truth file does not depend on k, so load it once per
    # dataset instead of once per k.
    gt_indices = np.load(gt_file)

    # ==========================================================
    # k sweep
    # ==========================================================
    for k in k_list:

        print(f"\n---- Running k = {k} ----")

        csv_file = f"{output_dir}/{method}_k{k}_nrepeat_seed{trial_seed}.csv"

        # Write the header only for a new file; reruns append to it.
        if not os.path.isfile(csv_file):
            with open(csv_file, mode="w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow([
                    "Dataset", "Method", "Init Time(ms)",
                    "Construction distance cost per point",
                    "NNDescent Time(ms)", "Distance computation",
                    "Total Time(ms)", "NNDescent Iterations",
                    "NNDescent Accuracy", "Number of random vector",
                    "TopK", "iProbe", "top_m", "qProbe",
                    "distance", "centering"
                ])

        # Column 0 is skipped and the next k columns are kept.
        # NOTE(review): column 0 is presumably each point itself - confirm.
        exact_kNN = gt_indices[:, 1:k+1]

        # ======================================================
        # n_repeat × m
        # ======================================================
        for n_repeats in n_repeats_list:
            for top_m in topm_list:

                print(f">> k={k}, n_repeats={n_repeats}, top_m={top_m}")

                index = FalconnLite.FalconnLite(n, d)
                index.set_params(
                    n_proj=D,
                    iProbe=iProbe,
                    top_m=top_m,
                    qProbe=qProbe,
                    distance=dist,
                    verbose=verbose,
                    n_threads=n_threads,
                    seed=-1
                )
                index.centering = True

                # ---------------- Init ----------------
                # Timed once per (n_repeats, top_m); the same initial graph
                # and init time are reused for every NNDescent run below.
                t_init_start = timeit.default_timer()
                indices, _ = index.approx_join(X, topK=k, n_repeats=n_repeats)
                t_init_end = timeit.default_timer()
                init_time_ms = (t_init_end - t_init_start) * 1000

                # ---------------- NNDescent ----------------
                for t in iters:
                    t_nn_start = timeit.default_timer()

                    index1 = NNDescent(
                        X,
                        n_neighbors=k,
                        random_state=trial_seed,
                        tree_init=False,
                        init_graph=indices,
                        metric="cosine",
                        n_iters=t,
                        n_jobs=n_threads
                    )
                    knn_indices, _ = index1.neighbor_graph

                    t_nn_end = timeit.default_timer()
                    nn_time_ms = (t_nn_end - t_nn_start) * 1000
                    total_time_ms = init_time_ms + nn_time_ms

                    acc = utils.getAcc(knn_indices, exact_kNN)

                    # Append immediately so partial results survive a crash.
                    with open(csv_file, mode="a", newline="") as f:
                        writer = csv.writer(f)
                        writer.writerow([
                            dataset_name,
                            f"{n_repeats} {method}",
                            f"{init_time_ms:.2f}",
                            0,
                            f"{nn_time_ms:.2f}",
                            # Falls back to 0 when the NNDescent object has
                            # no `counter` attribute.
                            getattr(index1, "counter", 0) / n,
                            f"{total_time_ms:.2f}",
                            t,
                            f"{acc:.4f}",
                            D,
                            k,
                            iProbe,
                            top_m,
                            qProbe,
                            dist,
                            index.centering
                        ])

        print(f"\n---- Finished k={k} ----\n")

print("\n[LOG] Experiment finished.")

# Put the original streams back before closing the log. Otherwise the
# interpreter's shutdown flush of sys.stdout/sys.stderr goes through Tee into
# a closed file and raises "ValueError: I/O operation on closed file".
if isinstance(sys.stdout, Tee):
    sys.stdout = sys.stdout.files[0]
if isinstance(sys.stderr, Tee):
    sys.stderr = sys.stderr.files[0]
log_file.close()
