import pickle
import utils
import numpy as np
import timeit
from pynndescent import NNDescent
import FalconnLite
import csv
import os
import sys

# ==============================================================
# LOG SETUP  (stdout + stderr → log file)
# ==============================================================

# Each run truncates and rewrites this log file.
log_name = "approx_join_D.log"
# Pin UTF-8 explicitly: the log receives everything written to stdout/stderr,
# and the platform default encoding may not be able to represent all of it.
log_file = open(log_name, "w", encoding="utf-8")

class Tee:
    """Minimal write-only stream that duplicates output to several streams.

    Intended as a drop-in replacement for ``sys.stdout`` / ``sys.stderr`` so
    that everything printed is also written to a log file.
    """

    def __init__(self, *files):
        """Store the target streams; each must provide ``write`` and ``flush``."""
        self.files = files

    def write(self, data):
        """Write ``data`` to every target and flush each one immediately.

        Returns the number of characters in ``data``, as text-stream
        ``write`` methods do, so callers that check the result keep working.
        """
        for f in self.files:
            f.write(data)
            # Flush per write so the log stays current if the run is killed.
            f.flush()
        return len(data)

    def flush(self):
        """Flush every target stream."""
        for f in self.files:
            f.flush()

# From here on, anything sent to stdout or stderr goes both to the original
# stream and to the log file.
sys.stdout, sys.stderr = Tee(sys.stdout, log_file), Tee(sys.stderr, log_file)

print(f"[LOG] Writing output to {log_name}")

# ==============================================================
# Dataset configurations
# ==============================================================



# Root directories for the raw datasets and the ground-truth neighbour files.
# They can be overridden through the environment; the fallbacks are
# placeholders -- TODO confirm the real locations for this machine.
DATA_DIR = os.environ.get("DATA_DIR", "data")
GT_DIR = os.environ.get("GT_DIR", "ground_truth")

# One entry per dataset. The main loop reads all four keys:
#   bin_file - path to the raw data file
#   n, d     - number of points and dimensionality (needed to map the file)
#   gt_file  - path to the saved ground-truth neighbour indices
dataset_paths = {
    "mnist": {
        "bin_file": f"{DATA_DIR}/mnist_X_60K_784.bin",
        # Shape taken from the file name (60K points x 784 dims) -- verify
        # against the actual file.
        "n": 60000,
        "d": 784,
        "gt_file": f"{GT_DIR}/Mnist60K_Cosine_k_100_indices.npy",
    }
}


# ==============================================================
# Fixed experiment parameters
# ==============================================================

# --- Settings held constant for the whole run ---
method = "approx_join"   # selects the init-phase call; also used in output paths
dist = "Cosine"          # forwarded to the index as `distance`
verbose = True
n_threads = 32

# --- Index parameters ---
D = 512                  # forwarded to the index as `n_proj`
iProbe = 3
qProbe = 1

# --- Swept values ---
iters = list(range(10, -1, -1))   # NNDescent iteration budgets: 10 down to 0
n_repeats_list = [8]
k_list = [20]
seed_list = [92]

# ==============================================================
# Main Loop: seeds × datasets × parameters
# ==============================================================

# For every seed and dataset: load and row-normalise the data, build an initial
# neighbour graph with FalconnLite, then time an NNDescent run seeded with that
# graph for each iteration budget in `iters`, appending one CSV row per run.
for trial_seed in seed_list:

    print(f"\n\n============ Running seed {trial_seed} ============\n")

    for dataset_name, data_info in dataset_paths.items():

        print(f"\n\n======== Dataset: {dataset_name} | Seed={trial_seed} ========\n")

        bin_file = data_info["bin_file"]
        n = data_info["n"]
        d = data_info["d"]
        gt_file = data_info["gt_file"]

        # ---------------- Load dataset ----------------
        # "Word" is loaded with np.load; every other dataset goes through
        # utils.mmap_bin with the declared (n, d).
        if dataset_name == "Word":
            X = np.load(bin_file)
        else:
            X = utils.mmap_bin(bin_file, n, d)

        # ---------------- Normalize (cosine only) ----------------
        # Scale every row to unit L2 norm; zero rows are left untouched by
        # substituting a divisor of 1.
        norms = np.linalg.norm(X, axis=1, keepdims=True)
        norms[norms == 0] = 1
        X = X / norms

        # ---------------- Output directory ----------------
        output_dir = f"{dataset_name}_result/{method}_centering"
        os.makedirs(output_dir, exist_ok=True)

        # NOTE(review): the file name says "10NN" while k_list currently
        # holds 20 -- confirm the name is intentional.
        csv_file = f"{output_dir}/{method}_seed{trial_seed}_10NN.csv"

        # Write header if file doesn't exist; otherwise rows are appended to
        # the results of earlier runs.
        if not os.path.isfile(csv_file):
            with open(csv_file, mode="w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow([
                    "Dataset", "Method", "Init Time(ms)",
                    "Construction distance cost per point",
                    "NNDescent Time(ms)", "Distance computation",
                    "Total Time(ms)", "NNDescent Iterations",
                    "NNDescent Accuracy", "Number of random vector",
                    "TopK", "iProbe", "top_m", "qProbe", "distance", "centering"
                ])

        # ---------------- Loop over K ----------------
        for k in k_list:

            print(f"\n---- K = {k} ----")

            # Keep columns 1..k of the ground truth; column 0 is dropped
            # (presumably the query point itself -- TODO confirm).
            exact_kNN = np.load(gt_file)[:, 1:k+1]
            topm_list = [25]

            for n_repeats in n_repeats_list:
                for top_m in topm_list:

                    print(f"\n>> n_repeats={n_repeats}, top_m={top_m}")

                    index = FalconnLite.FalconnLite(n, d)
                    # NOTE(review): seed=-1 is passed regardless of
                    # trial_seed -- confirm whether the init phase is meant
                    # to be seeded per trial.
                    index.set_params(
                        n_proj=D, iProbe=iProbe, top_m=top_m,
                        qProbe=qProbe, distance=dist, verbose=verbose,
                        n_threads=n_threads, seed=-1
                    )
                    index.centering = True

                    # ---------------- Init phase ----------------
                    t_init_start = timeit.default_timer()

                    if method == "approx_join":
                        indices, _ = index.approx_join(X, topK=k, n_repeats=n_repeats)
                    else:
                        # Fail here rather than with a NameError on `indices` below.
                        raise ValueError(f"Unsupported method: {method!r}")

                    t_init_end = timeit.default_timer()
                    init_time_ms = (t_init_end - t_init_start) * 1000

                    # ---------------- NNDescent Refinement ----------------
                    # Each iteration budget starts again from the same
                    # initial graph, so the timings are independent.
                    for t in iters:

                        t_nn_start = timeit.default_timer()

                        index1 = NNDescent(
                            X, n_neighbors=k, random_state=trial_seed,
                            tree_init=False, init_graph=indices,
                            metric="cosine", n_iters=t, n_jobs=n_threads
                        )
                        knn_indices, _ = index1.neighbor_graph

                        t_nn_end = timeit.default_timer()
                        nn_time_ms = (t_nn_end - t_nn_start) * 1000
                        # The init time is shared by every row of this
                        # (n_repeats, top_m) configuration.
                        total_time_ms = init_time_ms + nn_time_ms

                        acc = utils.getAcc(knn_indices[:, :], exact_kNN[:, :])

                        with open(csv_file, mode="a", newline="") as f:
                            writer = csv.writer(f)
                            writer.writerow([
                                dataset_name,
                                f"{n_repeats} {method}",
                                f"{init_time_ms:.2f}",
                                0,
                                f"{nn_time_ms:.2f}",
                                # Falls back to 0 when the NNDescent object
                                # has no `counter` attribute.
                                getattr(index1, "counter", 0) / n,
                                f"{total_time_ms:.2f}",
                                t,
                                f"{acc:.4f}",
                                D,
                                k,
                                iProbe,
                                top_m,
                                qProbe,
                                dist,
                                index.centering
                            ])

        print(f"\n---- Finished {dataset_name} seed={trial_seed} ----\n")

print("\n[LOG] Experiment finished.")

# Put the interpreter's original streams back before closing the log.
# Otherwise the flush at interpreter shutdown goes through the tee objects
# into a closed file and raises ValueError.
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
log_file.close()
