import faiss
import argparse
import json
import os
import pandas as pd
import time
from collections import defaultdict

import h5py
import numpy as np
from tqdm import tqdm
import heapq
from pympler import asizeof
from pytrec_eval import RelevanceEvaluator


# Project checkout root; configs/ and result/ are resolved beneath it.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"
# Populated from configs/transform_config.json in the __main__ block.
MODEL_ROOTPATH = ""
DATA_ROOTPATH = ""
OUTPUT_ROOTPATH = ""
# GPU memory budget in bytes; dense search slices the corpus so that one
# float32 slice fits in 80% of this value.
GPU_MEMORY_AVAILABLE = 48_000_000_000  # assumed 48 GB of GPU memory
models = []  # NOTE(review): not referenced anywhere in this file as shown
datasets = []  # NOTE(review): not referenced anywhere in this file as shown
# Full vector width (dense prefix + sparse tail); overwritten from the
# vector file's shape in the __main__ block.
total_dim = 3584
# Number of hits retrieved per query by each retriever.
k = 1000

def _dense_retrieval(q_vectors, d_vectors, dense_dim, top_k, gpu_memory):
    """Exact inner-product top-``top_k`` search on the dense prefix, on GPU 0.

    Only the first ``dense_dim`` columns of each vector are used. When the
    float32 document matrix would exceed 80% of ``gpu_memory`` bytes, the
    documents are split into slices that each fit. Every slice is searched
    separately and the per-slice hits are merged into one global ranking.

    Returns:
        ``(D, I, cost_time, cost_space)``: per-query scores and global
        document indices (``top_k`` columns, best first; FAISS reports
        missing hits as index -1), seconds spent in search + merge, and the
        byte size of the dense document matrix.
    """
    # Cast the queries once, outside the timed per-slice loop.
    q_dense = np.ascontiguousarray(q_vectors[:, :dense_dim], dtype="float32")
    d_dense = d_vectors[:, :dense_dim]
    cost_space = d_dense.nbytes
    n_docs = d_dense.shape[0]  # derived locally instead of a __main__-only global

    # Slice the corpus so that one slice's float32 copy fits the budget.
    bytes_per_doc = dense_dim * 4  # float32 takes 4 bytes
    slice_budget = gpu_memory * 0.8
    if n_docs * bytes_per_doc > slice_budget:
        step = max(1, int(slice_budget // bytes_per_doc))
        slices = [(s, min(s + step, n_docs)) for s in range(0, n_docs, step)]
    else:
        slices = [(0, n_docs)]

    # A single GPU resource object is shared by every slice index.
    res = faiss.StandardGpuResources()
    cost_time = 0.0
    D_list, I_list = [], []
    for start, end in tqdm(slices, desc="building inverted list:"):
        gpu_index = faiss.index_cpu_to_gpu(res, 0, faiss.IndexFlatIP(dense_dim))
        gpu_index.add(np.ascontiguousarray(d_dense[start:end], dtype="float32"))

        t0 = time.perf_counter()
        D_slice, I_slice = gpu_index.search(q_dense, top_k)
        # Shift real hits to global doc numbering. The -1 padding must stay -1,
        # otherwise it would alias document ``start - 1``.
        I_slice = np.where(I_slice >= 0, I_slice + start, -1)
        cost_time += time.perf_counter() - t0

        D_list.append(D_slice)
        I_list.append(I_slice)
        del gpu_index  # release this slice's GPU memory before the next one

    # Merge the per-slice candidates and keep the global top_k by score.
    t0 = time.perf_counter()
    D_all = np.concatenate(D_list, axis=1)
    I_all = np.concatenate(I_list, axis=1)
    order = np.argsort(-D_all, axis=1)[:, :top_k]  # descending score
    D_dense = np.take_along_axis(D_all, order, axis=1)
    I_dense = np.take_along_axis(I_all, order, axis=1)
    cost_time += time.perf_counter() - t0
    return D_dense, I_dense, cost_time, cost_space


def _select_sparse_dims(vectors, dense_dim, sparse_t):
    """Choose, per row, the fewest sparse dims needed to keep ``sparse_t`` of its energy.

    Energy is the squared L2 norm of the row. The dense prefix (first
    ``dense_dim`` columns) always counts toward the retained energy. Sparse
    columns are then added in decreasing order of squared value until
    ``dense + kept_sparse >= sparse_t * total``.

    Returns:
        ``(order, values, counts)``:
        - ``order[i]`` holds sparse column indices (relative to ``dense_dim``)
          sorted by decreasing squared value.
        - ``values[i]`` holds the matching raw entries.
        - ``counts[i]`` is how many leading entries of that ordering row ``i``
          keeps.
    """
    sparse = vectors[:, dense_dim:]
    squared = vectors ** 2
    total = squared.sum(axis=1, keepdims=True)
    dense_energy = squared[:, :dense_dim].sum(axis=1, keepdims=True)
    sparse_squared = squared[:, dense_dim:]
    required = np.maximum(total * sparse_t - dense_energy, 0)

    order = np.argsort(-sparse_squared, axis=1)
    cum = np.cumsum(np.take_along_axis(sparse_squared, order, axis=1), axis=1)
    # The number of dims needed equals the number of prefixes (counting the
    # empty one) whose energy is still below the requirement. The cap keeps
    # every dim, instead of none, when rounding stops cum from reaching it.
    counts = (required[:, 0] > 0).astype(np.intp) + np.sum(cum < required, axis=1)
    counts = np.minimum(counts, sparse_squared.shape[1])
    values = np.take_along_axis(sparse, order, axis=1)
    return order, values, counts


def _sparse_retrieval(q_vectors, d_vectors, dense_dim, sparse_t, top_k, batch_size=512_000):
    """Score queries against an inverted index built over the sparse tail.

    Documents are processed in batches of ``batch_size`` rows. Each document
    posts ``(global_doc_index, value)`` into the list of every sparse dim
    picked by ``_select_sparse_dims``. Queries are sparsified the same way
    and scored by dot product over the shared dims.

    Returns:
        ``(D, I, cost_time, cost_space)``: ``top_k`` scores and doc indices per
        query, padded with -9999 / -1. Also returns the seconds spent scoring
        and the deep size in bytes of the inverted index.
    """
    n_docs = d_vectors.shape[0]
    inverted_index = [[] for _ in range(d_vectors.shape[1] - dense_dim)]
    batches = [(s, min(s + batch_size, n_docs)) for s in range(0, n_docs, batch_size)]
    for start, end in tqdm(batches, desc="Processing sparse batches"):
        order, values, counts = _select_sparse_dims(d_vectors[start:end], dense_dim, sparse_t)
        for offset, (dims, vals, n_keep) in enumerate(zip(order, values, counts)):
            doc_idx = start + offset  # global document index, not batch-local
            for dim, val in zip(dims[:n_keep], vals[:n_keep]):
                inverted_index[dim].append((doc_idx, val))
        del order, values, counts

    cost_space = asizeof.asizeof(inverted_index)

    # Sparsify the queries with the same rule as the documents.
    order, values, counts = _select_sparse_dims(q_vectors, dense_dim, sparse_t)
    queries = [list(zip(dims[:n], vals[:n])) for dims, vals, n in zip(order, values, counts)]

    t0 = time.perf_counter()
    D_all, I_all = [], []
    for query in tqdm(queries, desc="sparse retrieving"):
        score = defaultdict(float)
        for dim, qval in query:
            for doc_idx, dval in inverted_index[dim]:
                score[doc_idx] += qval * dval
        top = heapq.nlargest(top_k, score.items(), key=lambda item: item[1])
        pad = top_k - len(top)
        I_all.append([doc for doc, _ in top] + [-1] * pad)
        D_all.append([val for _, val in top] + [-9999] * pad)
    cost_time = time.perf_counter() - t0

    D_all = np.array(D_all).reshape(-1, top_k)
    I_all = np.array(I_all).astype(int).reshape(-1, top_k)
    return D_all, I_all, cost_time, cost_space


def _merge_hybrid(D_dense, I_dense, D_sparse, I_sparse, top_k):
    """Fuse the dense and sparse rankings by summing each document's scores.

    A document found by only one retriever keeps that single score. Padding
    (doc index < 0) is dropped before fusing. Each query gets its ``top_k``
    fused documents; rows with fewer candidates are padded with -9999 / -1.
    """
    n_query = D_dense.shape[0]
    width = I_dense.shape[1] + I_sparse.shape[1]
    df = pd.DataFrame({
        "qid": np.repeat(np.arange(n_query), width),
        "docid": np.concatenate((I_dense, I_sparse), axis=1).reshape(-1),
        "score": np.concatenate((D_dense, D_sparse), axis=1).reshape(-1),
    })
    df = df[df["docid"] >= 0]
    fused = df.groupby(["qid", "docid"], as_index=False)["score"].sum()
    # Stable sort keeps docid-ascending order among ties, matching nlargest(keep="first").
    fused = fused.sort_values(["qid", "score"], ascending=[True, False], kind="stable")
    top = fused.groupby("qid").head(top_k)

    rank = top.groupby("qid").cumcount().to_numpy()
    rows = top["qid"].to_numpy()
    D_hybrid = np.full((n_query, top_k), -9999.0)
    I_hybrid = np.full((n_query, top_k), -1, dtype=int)
    D_hybrid[rows, rank] = top["score"].to_numpy()
    I_hybrid[rows, rank] = top["docid"].to_numpy()
    return D_hybrid, I_hybrid


def _build_run(qids, dids, D, I):
    """Convert ranked index matrices into a pytrec_eval run ``{qid: {docid: score}}``.

    Row ``i`` of ``D``/``I`` belongs to ``qids[i]``. Negative document
    indices are padding and are skipped.
    """
    return {
        qid: {dids[doc]: float(s) for doc, s in zip(doc_row, score_row) if doc >= 0}
        for qid, doc_row, score_row in zip(qids, I, D)
    }


def process_exp(q_vectors, d_vectors, qids, dids, qrel, dense_dim, sparse_t, is_hybrid):
    """Run one dense (or dense + sparse hybrid) retrieval experiment and evaluate it.

    Args:
        q_vectors: query vectors, one row per entry of ``qids``.
        d_vectors: document vectors, one row per entry of ``dids``.
        qids, dids: id arrays aligned with the vector rows.
        qrel: relevance judgements ``{qid: {docid: grade}}`` for pytrec_eval.
        dense_dim: number of leading columns used for dense search; the
            remaining columns form the sparse part.
        sparse_t: fraction of each vector's squared norm that the dense prefix
            plus the selected sparse dims must retain (hybrid mode only).
        is_hybrid: if true, fuse dense and sparse hits by summing their scores.

    Returns:
        dict with mean nDCG@10 and recall@1000, plus the time/space costs of the
        dense, sparse and combined pipelines (sparse costs are 0 when not hybrid).
    """
    D_dense, I_dense, cost_time_dense, cost_space_dense = _dense_retrieval(
        q_vectors, d_vectors, dense_dim, k, GPU_MEMORY_AVAILABLE
    )

    if is_hybrid:
        D_sparse, I_sparse, cost_time_sparse, cost_space_sparse = _sparse_retrieval(
            q_vectors, d_vectors, dense_dim, sparse_t, k
        )
        D_final, I_final = _merge_hybrid(D_dense, I_dense, D_sparse, I_sparse, k)
    else:
        cost_time_sparse, cost_space_sparse = 0, 0
        D_final, I_final = D_dense, I_dense

    run = _build_run(qids, dids, D_final, I_final)
    evaluator = RelevanceEvaluator(qrel, {'ndcg_cut_10', 'recall_1000'})
    per_query = evaluator.evaluate(run)
    ndcg_scores = [v['ndcg_cut_10'] for v in per_query.values()]
    recall_scores = [v['recall_1000'] for v in per_query.values()]

    return {
        "ndcg": float(np.mean(ndcg_scores)),
        "recall": float(np.mean(recall_scores)),
        "dense_cost_time": cost_time_dense,
        "dense_cost_space": cost_space_dense,
        "sparse_cost_time": cost_time_sparse,
        "sparse_cost_space": cost_space_sparse,
        "hybrid_cost_time": cost_time_dense + cost_time_sparse,
        "hybrid_cost_space": cost_space_dense + cost_space_sparse
    }


if __name__ == "__main__":
    # Evaluate only the first `query_num` qrel queries; -1 means all of them.
    query_num = -1

    # Resolve the model/data/output roots from the shared transform config.
    config_path = os.path.join(PROJECT_ROOTPATH, "configs", "transform_config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    MODEL_ROOTPATH = config["model_rootpath"]
    DATA_ROOTPATH = config["data_rootpath"]
    OUTPUT_ROOTPATH = config["output_rootpath"]

    # Experiment grid: model/dataset pair, retrieval mode, and per-run settings.
    with open(os.path.join(PROJECT_ROOTPATH, "configs", "hybrid_retrieval_config.json"), encoding="utf-8") as f:
        config = json.load(f)
    model_name = config["model"]
    dataset_name = config["dataset"]
    exps = config["exps"]
    is_hybrid = config["mode"] == "hybrid"

    # Create the output directory up front so a long run cannot fail at the final write.
    result_dir = os.path.join(PROJECT_ROOTPATH, "result")
    os.makedirs(result_dir, exist_ok=True)

    # Load the transformed query/document vectors and their ids.
    exp_name = f"{model_name}_{dataset_name}_exp"
    q_vector_path = os.path.join(OUTPUT_ROOTPATH, exp_name, "transform/q_vectors.h5")
    d_vector_path = os.path.join(OUTPUT_ROOTPATH, exp_name, "transform/d_vectors.h5")
    with h5py.File(q_vector_path, "r") as qhf, h5py.File(d_vector_path, "r") as dhf:
        qids = np.array([qid.decode() for qid in qhf["ids"][:]])
        dids = np.array([did.decode() for did in dhf["ids"][:]])
        # Kept as a module global for any code that still reads it.
        num_docs = dids.shape[0]

        # Ground-truth relevance judgements: {qid: {docid: grade}}.
        qrel = {}
        qrel_path = os.path.join(DATA_ROOTPATH, dataset_name, "qrels/test.tsv")
        with open(qrel_path, encoding="utf-8") as f:
            next(f, None)  # skip the header row
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                qid, docid, score = line.strip().split("\t")
                qrel.setdefault(qid, {})[docid] = int(score)

        if query_num != -1:
            qrel = dict(list(qrel.items())[:query_num])

        total_dim = qhf["vectors"].shape[1]
        q_vectors_full = np.empty((qhf["vectors"].shape[0], qhf["vectors"].shape[1]), dtype="float32")
        d_vectors_full = np.empty((dhf["vectors"].shape[0], dhf["vectors"].shape[1]), dtype="float32")
        qhf["vectors"].read_direct(q_vectors_full)
        dhf["vectors"].read_direct(d_vectors_full)

        # Keep only the queries that have judgements in the test qrels.
        # dtype=bool keeps the mask valid even when there are no queries.
        valid_q_mask = np.array([qid in qrel for qid in qids], dtype=bool)
        qids = qids[valid_q_mask]
        q_vectors_full = q_vectors_full[valid_q_mask]

    results = []
    for exp in exps:
        dense_dim = exp["dense_dim"]
        sparse_t = exp["sparse_threshold"]
        exp_result = process_exp(q_vectors_full, d_vectors_full, qids, dids, qrel, dense_dim, sparse_t, is_hybrid)
        print(exp)
        print(exp_result)
        results.append({
            "model": model_name,
            "dataset": dataset_name,
            "dense_dim": dense_dim,
            "sparse_t": sparse_t,
            "result": exp_result
        })

    result_path = os.path.join(result_dir, f"hybrid_retrieval_result_{time.strftime('%Y%m%d%H%M%S')}.json")
    with open(result_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print("done")

