import pickle

import utils
import numpy as np
import timeit
from pynndescent import NNDescent
import FalconnLite
import csv
import os

def normalize_rows(X):
    """Return a row-wise L2-normalized copy of ``X``.

    Rows whose norm is zero are divided by 1 and therefore stay all-zero.
    The division is performed out of place, so ``X`` itself is never written
    to (important when ``X`` is backed by a file or is read-only).

    Args:
        X: 2-D array-like holding one vector per row.

    Returns:
        A new array with the same shape as ``X``.
    """
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    norms[norms == 0] = 1  # avoid division by zero for all-zero rows
    return np.asarray(X) / norms


def ensure_csv_header(path, header):
    """Create ``path`` containing only ``header`` unless the file already exists."""
    if os.path.isfile(path):
        return
    with open(path, mode='w', newline='') as f:
        csv.writer(f).writerow(header)


def append_csv_row(path, row):
    """Append a single ``row`` to the CSV file at ``path``."""
    with open(path, mode='a', newline='') as f:
        csv.writer(f).writerow(row)


if __name__ == '__main__':
    # ---------------- Select Method ----------------
    # options: "bucket_sampling", "coll_counting", "dist_estimating", "approx_kNN"
    method = "bucket_sampling"
    # method = "coll_counting"
    # method = "dist_estimating"
    # method = "approx_kNN"

    dataset_name = "Mnist60K"
    # dataset_name = "Sift1M"
    # dataset_name = "Gist1M"

    # ---------------- Dataset Setup ----------------
    # NOTE(review): `dataset_paths` was referenced but never defined in this
    # file, so the script failed with NameError. The table below restores it.
    # n/d are the usual sizes of these public benchmarks; the file paths are
    # placeholders -- TODO: point them at the local .bin / ground-truth files.
    dataset_paths = {
        "Mnist60K": {
            "bin_file": "data/Mnist60K.bin",
            "n": 60000,
            "d": 784,
            "gt_file": "data/Mnist60K_gt.npy",
        },
        "Sift1M": {
            "bin_file": "data/Sift1M.bin",
            "n": 1000000,
            "d": 128,
            "gt_file": "data/Sift1M_gt.npy",
        },
        "Gist1M": {
            "bin_file": "data/Gist1M.bin",
            "n": 1000000,
            "d": 960,
            "gt_file": "data/Gist1M_gt.npy",
        },
    }

    # Select the dataset
    if dataset_name not in dataset_paths:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    data_info = dataset_paths[dataset_name]
    bin_file = data_info["bin_file"]
    n = data_info["n"]
    d = data_info["d"]
    gt_file = data_info["gt_file"]

    # ---------------- Load Dataset ----------------
    X = utils.mmap_bin(bin_file, n, d)

    # Cosine normalization. Done out of place: the open mode used by
    # utils.mmap_bin is not visible here, and an in-place `X /= norms` would
    # either write the normalized vectors back into the dataset file
    # (writable map) or raise (read-only map / integer dtype).
    X = normalize_rows(X)

    # ---------------- Ground Truth ----------------
    exact_kNN = np.load(gt_file)

    k = 50
    exact_kNN = exact_kNN[:, 1:k+1]  # remove self neighbor

    # ---------------- FalconnLite Setup ----------------
    D = 2**8
    iProbe = 5
    top_m = k
    qProbe = 5
    dist = "Cosine"
    verbose = True
    seed = 42  # shared by FalconnLite and NNDescent
    n_repeats = 4
    n_threads = 32

    index = FalconnLite.FalconnLite(n, d)
    index.set_params(n_proj=D, iProbe=iProbe, top_m=top_m,
                     qProbe=qProbe, distance=dist, verbose=verbose,
                     n_threads=n_threads, seed=seed)
    index.centering = True  # fixed to True for all runs
    # index.centering = False

    # ---------------- Output CSV Setup ----------------
    csv_file = "topm=k.csv"
    # csv_file = "FalconnLite_NNDescent_results.csv"
    ensure_csv_header(csv_file, [
        "Dataset", "Method", "Init Time(ms)", "NNDescent Time(ms)",
        "Total Time(ms)", "NNDescent Iterations", "NNDescent Accuracy", "TopK"
    ])

    # ---------------- FalconnLite Method Execution ----------------
    t_init_start = timeit.default_timer()

    if method == "bucket_sampling":
        indices = index.bucket_sampling(X, topK=k)
    elif method == "coll_counting":
        indices = index.coll_counting(X, topK=k, n_repeats=n_repeats)
    elif method == "dist_estimating":
        indices = index.dist_estimating(X, topK=k, n_repeats=n_repeats)
    elif method == "approx_kNN":
        # The returned distances are not used by this benchmark.
        indices, _ = index.approx_kNN(X, topK=k, n_repeats=n_repeats)
    else:
        raise ValueError(f"Unknown method: {method}")

    t_init_end = timeit.default_timer()
    init_time_ms = (t_init_end - t_init_start) * 1000
    # NOTE(review): getAcc is called as (exact, approx) here but as
    # (approx, exact) in the loop below -- confirm the helper is symmetric.
    print(f"{method}: Init Time={init_time_ms:.2f}ms, Accuracy={utils.getAcc(exact_kNN, indices):.4f}")

    # ---------------- NNDescent Test ----------------
    iters = [0, 1, 2, 3]  # iteration counts to benchmark; customize as needed

    # Warmup NNDescent. The call takes the same arguments for every iteration
    # count and sits outside the timed region, so it is run once instead of
    # once per loop pass.
    NNDescent(X, n_neighbors=k, random_state=seed, tree_init=False,
              metric="cosine", n_iters=1, n_jobs=n_threads)

    for t in iters:
        # Time only the refinement that starts from the FalconnLite graph.
        t_nn_start = timeit.default_timer()
        knn_indices, _ = NNDescent(
            X, n_neighbors=k, random_state=seed, tree_init=False,
            init_graph=indices,
            metric="cosine", n_iters=t, n_jobs=n_threads
        ).neighbor_graph
        t_nn_end = timeit.default_timer()
        nn_time_ms = (t_nn_end - t_nn_start) * 1000
        total_time_ms = init_time_ms + nn_time_ms
        acc = utils.getAcc(knn_indices, exact_kNN)

        print(f"NNDescent | Iterations={t} | recall@{k}={acc:.4f} | NNDescent Time={nn_time_ms:.2f}ms | Total Time={total_time_ms:.2f}ms")

        # ---------------- Write CSV ----------------
        append_csv_row(csv_file, [
            dataset_name,
            method,
            f"{init_time_ms:.2f}",
            f"{nn_time_ms:.2f}",
            f"{total_time_ms:.2f}",
            t,
            f"{acc:.4f}",
            k
        ])
