import json
import os
import h5py
import numpy as np
from tqdm import tqdm
import faiss

# Root of the project checkout; main() reads configs/ and writes result/ under it.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"

# Placeholders overwritten by main() from configs/transform_config.json.
MODEL_ROOTPATH = DATA_ROOTPATH = OUTPUT_ROOTPATH = ""

# GPU memory budget in bytes (48 GB); process_model sizes corpus slices from it.
GPU_MEMORY_AVAILABLE = 48 * 10**9

# Model names used by this script.
# NOTE(review): main() currently hard-codes these names instead of reading this list.
models = [
    "gte-base-en-v1.5",
    "gte-Qwen2-7B-instruct",
]

# Dataset names; replaced by main() with the "datasets" entry of the config.
datasets = []


def process_model(model, dataset):
    """Return each judged query's top-10 retrieved documents for one model.

    Loads the query and document vectors that ``model`` produced for
    ``dataset`` (``<OUTPUT_ROOTPATH>/<model>_<dataset>_exp/transform/
    {q,d}_vectors.h5``). Only queries with at least one positive judgement in
    ``<DATA_ROOTPATH>/<dataset>/qrels/test.tsv`` are kept. The function runs
    exact inner-product search with FAISS on GPU 0. The corpus is indexed in
    slices sized from ``GPU_MEMORY_AVAILABLE``, and the per-slice hits are
    merged into a running global top-10.

    Args:
        model: Model name. Names containing "base" use the first 768 vector
            columns; all others use the first 3584.
        dataset: Dataset directory name under DATA_ROOTPATH, also used in the
            experiment directory name under OUTPUT_ROOTPATH.

    Returns:
        Dict ``{qid: set(docids)}`` for every judged query present in the
        query file. A set may hold fewer than 10 ids, e.g. when the corpus
        has fewer than 10 documents.
    """
    top_k = 10
    print(f"\nProcessing {model} on {dataset}")
    dim = 768 if "base" in model else 3584

    # Locate the vector files for this model/dataset experiment.
    exp_name = f"{model}_{dataset}_exp"
    q_path = os.path.join(OUTPUT_ROOTPATH, exp_name, "transform/q_vectors.h5")
    d_path = os.path.join(OUTPUT_ROOTPATH, exp_name, "transform/d_vectors.h5")

    model_top10 = {}
    with h5py.File(q_path, "r") as qhf, h5py.File(d_path, "r") as dhf:
        # Ids are stored as byte strings; decode them to str.
        qids = np.array([qid.decode() for qid in qhf["ids"][:]])
        dids = np.array([did.decode() for did in dhf["ids"][:]])

        # Collect queries that have at least one positively judged document.
        qrel_path = os.path.join(DATA_ROOTPATH, dataset, "qrels/test.tsv")
        judged_qids = set()
        with open(qrel_path, encoding="utf-8") as f:
            next(f)  # skip header row
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines
                qid, _docid, score = line.split("\t")
                if int(score) > 0:
                    judged_qids.add(qid)

        # Explicit bool dtype so an empty id list still yields a valid mask.
        valid_q_mask = np.array([qid in judged_qids for qid in qids], dtype=bool)
        valid_qids = qids[valid_q_mask]
        if len(valid_qids) == 0:
            return model_top10

        # Read the query vectors, then apply the mask in NumPy rather than
        # passing it to h5py, whose fancy indexing is restrictive and slow for
        # long selections.
        q_vectors = qhf["vectors"][:, :dim][valid_q_mask].astype(np.float32)

        # Corpus rows are read from disk one slice at a time, so the full
        # document matrix is never resident in host memory.
        d_dataset = dhf["vectors"]
        num_docs = d_dataset.shape[0]
        # Each slice's float32 flat index is sized to ~40% of the GPU budget.
        slice_size = max(1, int(GPU_MEMORY_AVAILABLE * 0.4 / (dim * 4)))
        slices = [(i, min(i + slice_size, num_docs)) for i in range(0, num_docs, slice_size)]

        merged_scores = np.zeros((len(q_vectors), 0), dtype=np.float32)
        merged_indices = np.zeros((len(q_vectors), 0), dtype=np.int64)

        # One GPU resources object shared by every slice's index.
        res = faiss.StandardGpuResources()
        for start, end in tqdm(slices, desc="Processing slices"):
            slice_docs = np.ascontiguousarray(d_dataset[start:end, :dim], dtype=np.float32)

            # Build an exact inner-product index for this slice on GPU 0.
            gpu_index = faiss.index_cpu_to_gpu(res, 0, faiss.IndexFlatIP(dim))
            gpu_index.add(slice_docs)
            scores, indices = gpu_index.search(q_vectors, top_k)
            del gpu_index  # drop this slice's index before building the next

            # FAISS pads with label -1 when a slice holds fewer than top_k
            # docs. Keep that padding at -1 (score -inf) instead of shifting
            # it onto a real corpus row.
            hit = indices >= 0
            indices = np.where(hit, indices + start, -1)
            scores = np.where(hit, scores, -np.inf).astype(np.float32)

            merged_scores = np.hstack([merged_scores, scores])
            merged_indices = np.hstack([merged_indices, indices])

            # Keep only the global top-k across all slices seen so far.
            topk = min(top_k, merged_scores.shape[1])
            order = np.argsort(-merged_scores, axis=1)[:, :topk]
            merged_scores = np.take_along_axis(merged_scores, order, axis=1)
            merged_indices = np.take_along_axis(merged_indices, order, axis=1)

        # Record the top-10 doc ids per query, skipping any padding slots.
        for qid, row in zip(valid_qids, merged_indices[:, :top_k]):
            model_top10[qid] = set(dids[row[row >= 0]].tolist())

    return model_top10


def main():
    """Compare which relevant documents two models retrieve in their top-10.

    Reads ``configs/transform_config.json`` under PROJECT_ROOTPATH and stores
    its roots and dataset list in the module globals that ``process_model``
    reads. For each dataset, every positively judged (query id, doc id) pair
    in ``qrels/test.tsv`` is classified by whether it appears in gte-base's
    top-10, gte-Qwen2's top-10, both, or neither. ``total_docs`` is the number
    of such pairs (one per positive qrels line). Per-dataset counts are
    printed and saved to ``<PROJECT_ROOTPATH>/result/introf2_d.json``.
    """
    global datasets, MODEL_ROOTPATH, DATA_ROOTPATH, OUTPUT_ROOTPATH
    introf2_results = {}

    # Load configuration.
    config_path = os.path.join(PROJECT_ROOTPATH, "configs", "transform_config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    MODEL_ROOTPATH = config["model_rootpath"]
    DATA_ROOTPATH = config["data_rootpath"]
    OUTPUT_ROOTPATH = config["output_rootpath"]
    datasets = config["datasets"]

    for dataset in datasets:
        print(f"\nProcessing dataset: {dataset}")

        # Load every positively judged (query id, doc id) pair.
        qrel_pairs = []
        qrel_path = os.path.join(DATA_ROOTPATH, dataset, "qrels/test.tsv")
        with open(qrel_path, encoding="utf-8") as f:
            next(f)  # skip header
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines
                qid, docid, score = line.split("\t")
                if int(score) > 0:
                    qrel_pairs.append((qid, docid))

        # Top-10 retrieval for both models.
        base_top10 = process_model("gte-base-en-v1.5", dataset)
        qwen_top10 = process_model("gte-Qwen2-7B-instruct", dataset)

        # Count relevant pairs by which model(s) retrieved them.
        stats = {
            "both_correct": 0,
            "gte-base_only": 0,
            "gte-Qwen2_only": 0,
            "neither": 0,
            "total_docs": len(qrel_pairs),
        }
        for qid, docid in qrel_pairs:
            in_base = docid in base_top10.get(qid, ())
            in_qwen = docid in qwen_top10.get(qid, ())
            if in_base and in_qwen:
                bucket = "both_correct"
            elif in_base:
                bucket = "gte-base_only"
            elif in_qwen:
                bucket = "gte-Qwen2_only"
            else:
                bucket = "neither"
            stats[bucket] += 1

        introf2_results[dataset] = stats
        print(f"Results for {dataset}:")
        print(json.dumps(stats, indent=2))

    # Save results.
    output_dir = os.path.join(PROJECT_ROOTPATH, "result")
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "introf2_d.json"), "w", encoding="utf-8") as f:
        json.dump(introf2_results, f, indent=2)


if __name__ == "__main__":
    main()
