import pickle
import utils
import numpy as np
import timeit
from pynndescent import NNDescent
import FalconnLite
import csv
import os

def l2_normalize_rows(X):
    """L2-normalize the rows of ``X`` in place and return ``X``.

    Rows whose norm is zero are left unchanged (their norm is treated as 1),
    so the division never produces NaNs.

    NOTE(review): in ``main`` ``X`` comes from ``utils.mmap_bin``. If that
    returns a writable memmap opened in ``r+`` mode, this in-place division is
    written back to the ``.bin`` file on disk (idempotent for unit-norm rows,
    but worth confirming). A read-only array makes the in-place division raise.
    """
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    norms[norms == 0] = 1
    X /= norms
    return X


def ensure_csv_header(csv_file, header):
    """Create ``csv_file`` containing a single ``header`` row unless it exists.

    The parent directory is created first, so a fresh checkout does not fail
    with ``FileNotFoundError`` on the first run. An existing file is left
    untouched, so later runs keep appending below the original header.
    """
    parent = os.path.dirname(csv_file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    if not os.path.isfile(csv_file):
        with open(csv_file, mode='w', newline='') as f:
            csv.writer(f).writerow(header)


def main():
    """Run one FalconnLite-initialised NNDescent experiment.

    Builds an initial kNN graph with the selected FalconnLite method. It then
    refines that graph with pynndescent's ``NNDescent`` for each iteration
    count in ``iters`` and reports recall@k against the exact ground truth.
    One row per refinement run is appended to a results CSV.
    """
    # ---------------- Select Method ----------------
    # options: "bucket_sampling", "coll_counting", "dist_estimating",
    #          "approx_kNN", "approx_join"
    # method = "coll_counting"
    method = "approx_kNN"
    # method = "approx_join"
    valid_methods = ("bucket_sampling", "coll_counting", "dist_estimating",
                     "approx_kNN", "approx_join")
    # Fail fast, before the (potentially large) dataset is loaded.
    if method not in valid_methods:
        raise ValueError(f"Unknown method: {method}")

    # ---------------- Dataset Setup ----------------
    dataset_name = "landmark"

    DATA_DIR = "/path/to/dataset"
    GT_DIR = "/path/to/groundtruth"

    # Each entry needs "bin_file", "gt_file" and the vector dimensionality "d".
    # The row count "n" is optional. When absent, it is taken from the
    # ground-truth file, which holds one neighbour list per data point (its
    # first column is the point itself, i.e. an all-points self-join).
    dataset_paths = {
        "landmark": {
            "bin_file": f"{DATA_DIR}/landmark-dino-768-cosine.bin",
            "gt_file": f"{GT_DIR}/landmark_Cosine_k_100_indices.npy",
            "d": 768,  # matches the "768" in the data file name
        }
    }

    data_info = dataset_paths[dataset_name]
    bin_file = data_info["bin_file"]
    gt_file = data_info["gt_file"]
    d = data_info["d"]

    # ---------------- Ground Truth ----------------
    # Loaded before the dataset so its row count can supply ``n``.
    exact_kNN = np.load(gt_file)
    n = data_info.get("n", exact_kNN.shape[0])
    if exact_kNN.shape[0] != n:
        raise ValueError(
            f"Ground truth has {exact_kNN.shape[0]} rows but n={n}")
    k = 20
    # Column 0 is each point's own index (self neighbour); keep the next k.
    exact_kNN = exact_kNN[:, 1:k+1]

    # ---------------- Load Dataset ----------------
    X = utils.mmap_bin(bin_file, n, d)
    # Cosine normalization (in place, zero rows left as-is).
    X = l2_normalize_rows(X)

    # ---------------- FalconnLite Setup ----------------
    D = 2**8  # passed to FalconnLite as n_proj
    iProbe = 3
    top_m = 25
    qProbe = 9
    dist = "Cosine"
    verbose = True
    seed = -1
    n_repeats = 4
    n_threads = 32

    # Parameter sweep used in the experiments (edit the values above):
    #   top_m  in {k, 2k, 5k}
    #   iProbe in {3, 5, 10}   (with top_m = 2k)
    #   qProbe in {3, 5, 10}   (with top_m = 2k, iProbe = 3)

    index = FalconnLite.FalconnLite(n, d)
    index.set_params(n_proj=D, iProbe=iProbe, top_m=top_m,
                     qProbe=qProbe, distance=dist, verbose=verbose,
                     n_threads=n_threads, seed=seed)
    index.centering = True  # always True for these experiments

    # ---------------- Output CSV Setup ----------------
    csv_file = f"landmark_result/{dataset_name}_{method}_results1.csv"
    ensure_csv_header(csv_file, [
        "Dataset", "Method", "Init Time(ms)", "Construction distance cost per point",
        "NNDescent Time(ms)", "Distance computation",
        "Total Time(ms)", "NNDescent Iterations", "NNDescent Accuracy",
        "TopK", "iProbe", "top_m", "qProbe", "distance", "centering",
    ])

    # ---------------- FalconnLite Method Execution ----------------
    t_init_start = timeit.default_timer()

    if method == "bucket_sampling":
        indices = index.bucket_sampling(X, topK=k)
    elif method == "coll_counting":
        indices = index.coll_counting(X, topK=k, n_repeats=n_repeats)
    elif method == "dist_estimating":
        indices = index.dist_estimating(X, topK=k, n_repeats=n_repeats)
    elif method == "approx_kNN":
        indices, distances = index.approx_kNN(X, topK=k, n_repeats=n_repeats)
    elif method == "approx_join":
        indices, distances = index.approx_join(X, topK=k, n_repeats=n_repeats)
    else:
        # Defensive: only reachable if valid_methods and this chain drift apart.
        raise ValueError(f"Unknown method: {method}")

    t_init_end = timeit.default_timer()
    init_time_ms = (t_init_end - t_init_start) * 1000
    # NOTE(review): this call passes (exact, approx), while the NNDescent
    # recall below passes (approx, exact). That only matters if utils.getAcc
    # is not symmetric. Confirm its signature and make the two calls agree.
    print(f"{method}: Init Time={init_time_ms:.2f}ms, Accuracy={utils.getAcc(exact_kNN, indices):.4f}")

    # ---------------- NNDescent Test ----------------
    iters = [10]  # NNDescent iteration counts to test; customize as needed
    nnd_seed = 42  # kept separate from the FalconnLite ``seed`` above

    # Optional warm-up run (presumably to keep one-off JIT compilation out of
    # the timed runs):
    # NNDescent(X, n_neighbors=k, random_state=nnd_seed, tree_init=False,
    #           metric="cosine", n_iters=1, n_jobs=n_threads)

    for t in iters:
        # Each t is an independent run starting from the same FalconnLite graph.
        t_nn_start = timeit.default_timer()
        index1 = NNDescent(X, n_neighbors=k, random_state=nnd_seed, tree_init=False,
                           init_graph=indices, metric="cosine", n_iters=t, n_jobs=n_threads)
        # ``neighbor_graph`` is a property that rebuilds/copies its arrays on
        # every access, so read it exactly once inside the timed region.
        knn_indices, _ = index1.neighbor_graph
        t_nn_end = timeit.default_timer()
        # Debug output: neighbours of the first point (outside the timing).
        print(knn_indices[0])

        nn_time_ms = (t_nn_end - t_nn_start) * 1000.0
        total_time_ms = nn_time_ms + init_time_ms

        acc = utils.getAcc(knn_indices, exact_kNN)

        print(f"NNDescent | Iterations={t} | recall@{k}={acc:.4f} | Initial_time={init_time_ms:.2f}ms | "
              f"iter_time={nn_time_ms:.2f}ms | Total={total_time_ms:.2f}ms")

        # ---------------- Write CSV ----------------
        with open(csv_file, mode='a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                dataset_name,
                f"{n_repeats} {method}",
                f"{init_time_ms:.2f}",
                0,                       # Construction distance cost per point: not measured here
                f"{nn_time_ms:.2f}",     # NNDescent Time (ms) for this run only
                # NOTE(review): ``counter`` is not part of the upstream
                # pynndescent API; assumes a patched build that counts
                # distance computations — TODO confirm.
                index1.counter,
                f"{total_time_ms:.2f}",  # Total Time (ms) = init + this NNDescent run
                t,                       # NNDescent Iterations for this run
                f"{acc:.4f}",
                k,
                iProbe,
                top_m,
                qProbe,
                dist,
                index.centering,
            ])


if __name__ == '__main__':
    main()
