import argparse
import json
import os
import h5py
import numpy as np
from copy import deepcopy
from tqdm import tqdm
import faiss
from pytrec_eval import RelevanceEvaluator

# Root of the project checkout; configs/ and result/ are resolved relative to it.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"
# Placeholders: when this file is run as a script they are overwritten from
# configs/transform_config.json before main() is called.
MODEL_ROOTPATH = ""
DATA_ROOTPATH = ""
OUTPUT_ROOTPATH = ""
# GPU memory budget in bytes (48 GB of GPU memory).
GPU_MEMORY_AVAILABLE = 48 * 10**9

# Experiment grid.
models = ["gte-Qwen2-7B-instruct"]
datasets = ["msmarco"]
# Embedding sizes to evaluate, largest first.
# NOTE(review): 3096 breaks the pattern of the other sizes (multiples of 256
# from 512 upward); it may be a typo for 3072 -- confirm before comparing runs.
dims = [
    3584, 3096, 2560, 2048, 1792, 1536, 1024,
    768, 512, 384, 256, 128, 64, 32,
]
# Number of times the experiment is repeated (one random permutation each).
N_TRIALS = 10
# Number of leading permuted dimensions kept per trial; presumably the model's
# native embedding size -- confirm against the stored vectors.
max_dim = 3584

def _load_qrels(qrel_path):
    """Read a qrels TSV into ``{query_id: {doc_id: relevance}}``.

    The first line is a header and is skipped. Every other non-blank line must
    contain exactly three tab-separated fields: query id, doc id and an
    integer relevance grade.
    """
    qrel = {}
    with open(qrel_path, encoding="utf-8") as f:
        next(f, None)  # header row
        for line in f:
            line = line.strip()
            if not line:
                continue
            qid, docid, score = line.split("\t")
            qrel.setdefault(qid, {})[docid] = int(score)
    return qrel


def _sliced_topk(q_vectors, d_vectors_full, cols, k, gpu_res, docs_per_slice):
    """Exact inner-product top-``k`` search over columns ``cols`` of all documents.

    Documents are indexed on GPU 0 in row slices of at most ``docs_per_slice``
    rows so a single flat index stays within the memory budget; per-slice hits
    are merged on the CPU.

    Args:
        q_vectors: contiguous float32 query matrix (one row per query), already
            restricted to ``cols``.
        d_vectors_full: float32 document matrix (one row per document) holding
            every embedding dimension.
        cols: column indices of ``d_vectors_full`` to search over.
        k: number of hits to keep per query.
        gpu_res: ``faiss.StandardGpuResources`` shared by all GPU indexes.
        docs_per_slice: maximum number of documents per GPU index (>= 1).

    Returns:
        ``(scores, doc_indices)``, each of shape ``(num_queries, k)`` and
        ordered by descending score. Slots without a real hit hold index -1
        and score -inf.
    """
    dim = len(cols)
    num_docs = d_vectors_full.shape[0]
    D_list, I_list = [], []
    for start in range(0, num_docs, docs_per_slice):
        end = min(start + docs_per_slice, num_docs)
        # Copy only this slice's selected columns; faiss needs a contiguous
        # float32 buffer anyway, so no full permuted matrix is ever built.
        slice_d = np.ascontiguousarray(d_vectors_full[start:end][:, cols])
        index = faiss.index_cpu_to_gpu(gpu_res, 0, faiss.IndexFlatIP(dim))
        index.add(slice_d)
        D_slice, I_slice = index.search(q_vectors, k)
        # Release the GPU index before the next slice's index is allocated.
        del index, slice_d
        # faiss reports missing hits (slice smaller than k) as id -1. Shift only
        # real hits to global document indices so padding never aliases a
        # real document, and push padding to the end of the ranking.
        missing = I_slice < 0
        I_list.append(np.where(missing, -1, I_slice + start))
        D_list.append(np.where(missing, -np.inf, D_slice))
    D_all = np.concatenate(D_list, axis=1)
    I_all = np.concatenate(I_list, axis=1)
    order = np.argsort(-D_all, axis=1)[:, :k]
    return (np.take_along_axis(D_all, order, axis=1),
            np.take_along_axis(I_all, order, axis=1))


def main():
    """Run the random-dimension-subset retrieval experiment.

    For every (model, dataset) pair, load the query/document vectors from
    ``OUTPUT_ROOTPATH/<model>_<dataset>_exp/transform/{q,d}_vectors.h5`` and
    the qrels from ``DATA_ROOTPATH/<dataset>/qrels/test.tsv``. In each of
    ``N_TRIALS`` trials, draw a random permutation of the embedding
    dimensions. For every size in ``dims`` (up to ``min(full_dim, max_dim)``),
    keep the first ``dim`` permuted dimensions, run exact inner-product top-10
    retrieval on GPU 0 and record the mean NDCG@10 over queries that have
    qrels. Results are written to
    ``PROJECT_ROOTPATH/result/introf1_optimized.json`` as
    ``{dataset: {dim: [ndcg of each trial, ...]}}``.
    """
    np.random.seed(1234)
    results = {}
    top_k = 10
    ndcg_measure = f"ndcg_cut_{top_k}"
    # One GPU resource object is shared by every index instead of allocating
    # fresh scratch memory for each slice.
    gpu_res = faiss.StandardGpuResources()

    for model_name in models:
        for dataset_name in datasets:
            exp_name = f"{model_name}_{dataset_name}_exp"
            q_vector_path = os.path.join(OUTPUT_ROOTPATH, exp_name, "transform/q_vectors.h5")
            d_vector_path = os.path.join(OUTPUT_ROOTPATH, exp_name, "transform/d_vectors.h5")

            qrel = _load_qrels(os.path.join(DATA_ROOTPATH, dataset_name, "qrels/test.tsv"))
            # The evaluator depends only on the qrels, so build it once per pair.
            evaluator = RelevanceEvaluator(qrel, {ndcg_measure})

            # Read ids and complete vectors into memory; the HDF5 files are not
            # needed afterwards.
            with h5py.File(q_vector_path, "r") as qhf, h5py.File(d_vector_path, "r") as dhf:
                qids = np.array([qid.decode() for qid in qhf["ids"][:]])
                dids = np.array([did.decode() for did in dhf["ids"][:]])
                q_vectors_full = qhf["vectors"][:].astype("float32", copy=False)
                d_vectors_full = dhf["vectors"][:].astype("float32", copy=False)
            full_dim = q_vectors_full.shape[1]

            # Only queries with relevance judgements are searched and evaluated.
            valid_q_mask = np.array([qid in qrel for qid in qids], dtype=bool)
            valid_qids = qids[valid_q_mask]
            q_vectors_valid = q_vectors_full[valid_q_mask]
            del q_vectors_full

            # At most ``max_dim`` permuted dimensions are kept per trial, so
            # sizes above min(full_dim, max_dim) cannot be evaluated.
            usable_dim = min(full_dim, max_dim)
            eval_dims = [dim for dim in dims if dim <= usable_dim]

            # NOTE(review): results are keyed by dataset only, so with several
            # entries in ``models`` their trials are appended to the same lists.
            dataset_results = results.setdefault(dataset_name, {})
            for dim in eval_dims:
                dataset_results.setdefault(dim, [])

            for trial in range(N_TRIALS):
                print(f"\nTrial {trial + 1}/{N_TRIALS} for {dataset_name}")

                # Fresh random order of the ORIGINAL columns each trial; the
                # vector arrays are never modified, so trials are independent.
                dim_permutation = np.random.permutation(full_dim)[:max_dim]

                for dim in eval_dims:
                    print(f"Processing dim {dim}...")
                    cols = dim_permutation[:dim]
                    # Documents per GPU slice: 40% of the memory budget divided
                    # by the float32 size of one ``dim``-wide row.
                    docs_per_slice = max(1, int(GPU_MEMORY_AVAILABLE * 0.4 // (dim * 4)))
                    q_vectors = np.ascontiguousarray(q_vectors_valid[:, cols])

                    D_final, I_final = _sliced_topk(
                        q_vectors, d_vectors_full, cols, top_k, gpu_res, docs_per_slice)

                    # pytrec_eval run: {qid: {doc_id: score}}; padding slots skipped.
                    run = {
                        str(qid): {
                            str(dids[doc_idx]): float(score)
                            for doc_idx, score in zip(I_final[i], D_final[i])
                            if doc_idx >= 0
                        }
                        for i, qid in enumerate(valid_qids)
                    }

                    trial_results = evaluator.evaluate(run)
                    mean_ndcg = np.mean([v[ndcg_measure] for v in trial_results.values()])
                    dataset_results[dim].append(float(round(mean_ndcg, 4)))
                    print(f"Dim {dim} NDCG@{top_k}: {mean_ndcg:.4f}")

            # Drop this pair's large arrays before the next pair is loaded.
            del d_vectors_full, q_vectors_valid

    # Save results.
    output_path = os.path.join(PROJECT_ROOTPATH, "result/introf1_optimized.json")
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(results, f, indent=2)


if __name__ == "__main__":
    # Pull the model/data/output roots from the project config. These are
    # module-level assignments, so main() reads the updated globals.
    config_path = os.path.join(PROJECT_ROOTPATH, "configs", "transform_config.json")
    with open(config_path) as config_file:
        config = json.load(config_file)
    MODEL_ROOTPATH, DATA_ROOTPATH, OUTPUT_ROOTPATH = (
        config["model_rootpath"],
        config["data_rootpath"],
        config["output_rootpath"],
    )
    # The dataset list is deliberately NOT taken from the config; to do so,
    # assign datasets = config["datasets"] here.
    main()
