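# Variant of the hybrid-retrieval experiment:
# - dense retrieval searches one query at a time instead of in a single batch;
# - the reported space cost is still the total index size, but the reported time is the average per query;
# - only a random subset of the queries is evaluated, to make runs faster.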
import random

import faiss
import argparse
import json
import os
import pandas as pd
import time
from collections import defaultdict

import h5py
import numpy as np
from tqdm import tqdm
import heapq
from pympler import asizeof
from pytrec_eval import RelevanceEvaluator
import csv


# Project checkout root; the configs/ directory and the result/ CSV live below it.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"
# These three are filled in from configs/transform_config.json when run as a script.
MODEL_ROOTPATH = ""
DATA_ROOTPATH = ""
OUTPUT_ROOTPATH = ""
# Memory budget in bytes (192 GB -- the old comment said 48GB, which did not match the value).
# Dense retrieval splits the document matrix into slices once it needs more than 80% of this.
# NOTE(review): the dense index used below is faiss.IndexFlatIP (a CPU index), so this
# effectively bounds host memory rather than GPU memory -- confirm and adjust to the machine.
GPU_MEMORY_AVAILABLE = 192_000_000_000
# Not referenced in the code visible here; kept for compatibility with other users of this module.
models = []
datasets = []
# Fixed seed so the random query subset drawn with random.sample is reproducible.
rand_seed = 1234
random.seed(rand_seed)
# Placeholder vector width; the script overwrites it with the real width read from the HDF5 file.
total_dim = 3584
# Number of documents retrieved per query.
k = 1000

def process_exp(q_vectors, d_vectors, qids, dids, qrel, dense_dim, sparse_t, is_hybrid):
    """Run one retrieval experiment (dense-only or dense+sparse hybrid) and evaluate it.

    The first ``dense_dim`` columns of every vector form the dense part. It is searched
    exhaustively with a FAISS inner-product index. The remaining columns form the sparse
    part. Each vector keeps only its highest-energy sparse dimensions (controlled by
    ``sparse_t``), and documents are scored through an in-memory inverted index. In hybrid
    mode a document's final score is the sum of its dense and sparse scores. A document
    returned by only one retriever gets only that retriever's score.

    Args:
        q_vectors: query matrix, rows aligned with ``qids``.
        d_vectors: document matrix, rows aligned with ``dids``.
        qids: query ids, one per row of ``q_vectors``.
        dids: document ids, one per row of ``d_vectors``.
        qrel: relevance judgments ``{qid: {docid: relevance}}`` passed to pytrec_eval.
        dense_dim: number of leading columns treated as the dense part.
        sparse_t: energy threshold for sparse pruning. Per vector, sparse dims are kept in
            descending squared magnitude while their cumulative squared magnitude stays
            <= max(sparse_t * ||v||^2 - ||v_dense||^2, 0).
        is_hybrid: when False only dense retrieval runs and the sparse costs are 0.

    Returns:
        dict with mean ``ndcg`` (nDCG@10), mean ``recall`` (recall@1000), and for the
        dense, sparse and hybrid runs the average search time per query and the index
        size. Sparse space is reported as -1 because measuring it is disabled.

    Raises:
        ValueError: if ``q_vectors`` has no rows (per-query averages would be undefined).
    """
    if q_vectors.shape[0] == 0:
        raise ValueError("process_exp needs at least one query vector")

    def select_sparse_dims(vectors, dense_dim, sparse_t):
        """Rank each row's sparse dims by energy and decide how many of them to keep.

        Returns ``(dim_ids, values, keep)``. ``dim_ids[i]`` and ``values[i]`` are row i's
        sparse column indices and values, sorted by descending squared magnitude.
        ``keep[i]`` is the length of the leading prefix of row i to retain.
        """
        sparse = vectors[:, dense_dim:]
        squared = vectors ** 2
        total_energy = np.sum(squared, axis=1, keepdims=True)
        dense_energy = np.sum(squared[:, :dense_dim], axis=1, keepdims=True)
        sparse_squared = squared[:, dense_dim:]
        # Energy the sparse part must still contribute to reach sparse_t of the total.
        required = np.maximum(total_energy * sparse_t - dense_energy, 0)
        dim_ids = np.argsort(-sparse_squared, axis=1)
        values = np.take_along_axis(sparse, dim_ids, axis=1)
        cum = np.cumsum(np.take_along_axis(sparse_squared, dim_ids, axis=1), axis=1)
        # cum is non-decreasing, so counting entries <= required gives the prefix length.
        # (np.argmax(cum > required) gave the same count except when no entry exceeds
        # `required`: it then returned 0 and dropped every dim instead of keeping them all,
        # and it raised outright when there were no sparse columns.)
        keep = np.count_nonzero(cum <= required, axis=1)
        return dim_ids, values, keep

    def dense_retrieval(q_vectors, d_vectors, dense_dim):
        cost_time = 0
        q_dense = q_vectors[:, :dense_dim]
        d_dense = d_vectors[:, :dense_dim]
        cost_space = d_dense.nbytes
        # Derived from the argument (it used to read a global only defined under __main__).
        n_docs = d_dense.shape[0]

        # Split the documents into slices so each FAISS index fits within the memory budget.
        required_memory = n_docs * dense_dim * 4  # float32 takes 4 bytes per value
        slice_threshold = GPU_MEMORY_AVAILABLE * 0.8
        if required_memory > slice_threshold:
            max_docs_per_slice = max(1, int(slice_threshold // (dense_dim * 4)))
            slices = [(s, min(s + max_docs_per_slice, n_docs))
                      for s in range(0, n_docs, max_docs_per_slice)]
        else:
            slices = [(0, n_docs)]

        D_list, I_list = [], []
        for start, end in tqdm(slices, desc="dense retrieval:"):
            index = faiss.IndexFlatIP(dense_dim)
            index.add(d_dense[start:end].astype('float32'))

            # Only the search is timed; building the index is excluded.
            t1 = time.time()
            D_each_q, I_each_q = [], []
            # One search call per query, so the time reflects per-query latency rather
            # than batched throughput.
            for i in range(q_dense.shape[0]):
                D, I = index.search(q_dense[i:i + 1].astype('float32'), k)
                D_each_q.append(D)
                I_each_q.append(I)
            D_slice = np.vstack(D_each_q)
            I_slice = np.vstack(I_each_q)
            # Shift slice-local ids to global ids. FAISS pads with label -1 when a slice
            # holds fewer than k docs; leave those alone, otherwise -1 would turn into the
            # valid id start-1.
            I_slice = np.where(I_slice >= 0, I_slice + start, I_slice)
            D_list.append(D_slice)
            I_list.append(I_slice)
            cost_time += time.time() - t1

        # Merge the per-slice top-k lists into one global top-k, sorted by score descending.
        t3 = time.time()
        D_all = np.concatenate(D_list, axis=1)
        I_all = np.concatenate(I_list, axis=1)
        order = np.argsort(-D_all, axis=1)[:, :k]
        D_dense = np.take_along_axis(D_all, order, axis=1)
        I_dense = np.take_along_axis(I_all, order, axis=1)
        cost_time += time.time() - t3
        cost_time /= q_vectors.shape[0]
        return D_dense, I_dense, cost_time, cost_space

    def sparse_retrieval(q_vectors, d_vectors, dense_dim, sparse_t):
        n_docs = d_vectors.shape[0]
        batch_size = 512_000
        # inverted_index[dim] is a list of (global doc row, value) pairs.
        inverted_index = [[] for _ in range(d_vectors.shape[1] - dense_dim)]
        batches = [(s, min(s + batch_size, n_docs)) for s in range(0, n_docs, batch_size)]
        for start, end in tqdm(batches, desc="building inverted list"):
            dim_ids, values, keep = select_sparse_dims(d_vectors[start:end], dense_dim, sparse_t)
            for row in range(end - start):
                # Store the global row: the batch-local row would collide across batches.
                doc_idx = start + row
                for j in range(keep[row]):
                    inverted_index[dim_ids[row, j]].append((doc_idx, values[row, j]))
            del dim_ids, values, keep

        # Measuring the index size (asizeof) is disabled, so -1 is reported instead.
        cost_space = -1

        # Prune each query with the same rule as the documents: a list of (dim, value) pairs.
        q_dim_ids, q_values, q_keep = select_sparse_dims(q_vectors, dense_dim, sparse_t)
        queries = [
            [(q_dim_ids[i, j], q_values[i, j]) for j in range(q_keep[i])]
            for i in range(q_vectors.shape[0])
        ]

        t1 = time.time()
        D_all, I_all = [], []
        for query in tqdm(queries, desc="sparse retrieving"):
            score = defaultdict(float)
            for dim, qval in query:
                for doc_idx, dval in inverted_index[dim]:
                    score[doc_idx] += qval * dval
            top = heapq.nlargest(k, score.items(), key=lambda item: item[1])
            # Pad to k results: id -1 and score -9999 mark "no document".
            pad = k - len(top)
            I_all.append([doc_idx for doc_idx, _ in top] + [-1] * pad)
            D_all.append([val for _, val in top] + [-9999] * pad)
        cost_time = time.time() - t1

        D_all = np.array(D_all)
        I_all = np.array(I_all).astype(int)
        cost_time /= q_vectors.shape[0]
        return D_all, I_all, cost_time, cost_space

    def get_hybrid_result(D_dense, I_dense, D_sparse, I_sparse):
        n_query = D_dense.shape[0]
        # Pre-fill with the padding convention so a query with fewer than k distinct
        # candidates still yields a full (n_query, k) row.
        D_hybrid = np.full((n_query, k), -9999.0)
        I_hybrid = np.full((n_query, k), -1, dtype=int)
        for q in range(n_query):
            ids = np.concatenate((I_dense[q], I_sparse[q]))
            vals = np.concatenate((D_dense[q], D_sparse[q])).astype(np.float64)
            # Sum the scores of documents found by both retrievers. np.unique sorts by doc
            # id, and the stable sort keeps ties in ascending doc-id order.
            uniq_ids, inverse = np.unique(ids, return_inverse=True)
            summed = np.zeros(uniq_ids.shape[0], dtype=np.float64)
            np.add.at(summed, inverse.reshape(-1), vals)
            top = np.argsort(-summed, kind="stable")[:k]
            D_hybrid[q, :top.size] = summed[top]
            I_hybrid[q, :top.size] = uniq_ids[top]
        return D_hybrid, I_hybrid

    D_dense, I_dense, cost_time_dense, cost_space_dense = dense_retrieval(q_vectors, d_vectors, dense_dim)

    if is_hybrid:
        D_sparse, I_sparse, cost_time_sparse, cost_space_sparse = sparse_retrieval(
            q_vectors, d_vectors, dense_dim, sparse_t)
        D_hybrid, I_hybrid = get_hybrid_result(D_dense, I_dense, D_sparse, I_sparse)
    else:
        D_hybrid, I_hybrid = D_dense, I_dense
        cost_time_sparse = 0
        cost_space_sparse = 0

    # Build the pytrec_eval run {qid: {docid: score}}, skipping padding entries (negative ids).
    run = {}
    for qid, doc_indices, scores in zip(qids, I_hybrid, D_hybrid):
        run[qid] = {
            dids[doc_idx]: float(score)
            for doc_idx, score in zip(doc_indices, scores)
            if doc_idx >= 0
        }

    evaluator = RelevanceEvaluator(qrel, {'ndcg_cut_10', 'recall_1000'})
    per_query = evaluator.evaluate(run)
    ndcg_scores = [v['ndcg_cut_10'] for v in per_query.values()]
    recall_scores = [v['recall_1000'] for v in per_query.values()]

    return {
        "ndcg": float(np.mean(ndcg_scores)),
        "recall": float(np.mean(recall_scores)),
        "dense_cost_time": cost_time_dense,
        "dense_cost_space": cost_space_dense,
        "sparse_cost_time": cost_time_sparse,
        "sparse_cost_space": cost_space_sparse,
        "hybrid_cost_time": cost_time_dense + cost_time_sparse,
        "hybrid_cost_space": cost_space_dense + cost_space_sparse,
    }


if __name__ == "__main__":
    # Number of judged test queries to sample for evaluation; -1 evaluates all of them.
    query_num = 100
    config_path = os.path.join(PROJECT_ROOTPATH, "configs", "transform_config.json")
    with open(config_path) as f:
        config = json.load(f)
        MODEL_ROOTPATH = config["model_rootpath"]
        DATA_ROOTPATH = config["data_rootpath"]
        OUTPUT_ROOTPATH = config["output_rootpath"]

    with open(os.path.join(PROJECT_ROOTPATH, "configs", "hybrid_retrieval_config.json")) as f:
        config = json.load(f)
        model_name = config["model"]
        dataset_name = config["dataset"]
        exps = config["exps"]
        is_hybrid = config["mode"] == "hybrid"

    # Load the ids and vectors written by the transform step.
    exp_name = f"{model_name}_{dataset_name}_exp"
    q_vector_path = os.path.join(OUTPUT_ROOTPATH, exp_name, "transform/q_vectors.h5")
    d_vector_path = os.path.join(OUTPUT_ROOTPATH, exp_name, "transform/d_vectors.h5")
    with h5py.File(q_vector_path, "r") as qhf, h5py.File(d_vector_path, "r") as dhf:
        qids = np.array([qid.decode() for qid in qhf["ids"][:]])
        dids = np.array([did.decode() for did in dhf["ids"][:]])
        num_docs = dids.shape[0]

        # Relevance judgments {qid: {docid: relevance}} from the test qrels TSV.
        qrel = {}
        qrel_path = os.path.join(DATA_ROOTPATH, dataset_name, "qrels/test.tsv")
        with open(qrel_path) as f:
            next(f, None)  # skip the header row
            for line in f:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                qid, docid, score = line.strip().split("\t")
                qrel.setdefault(qid, {})[docid] = int(score)

        # Evaluate a reproducible random subset of the judged queries (seeded via rand_seed).
        if query_num != -1 and len(qrel) > query_num:
            qrel = dict(random.sample(list(qrel.items()), query_num))

        total_dim = qhf["vectors"].shape[1]
        q_vectors_full = np.empty((qhf["vectors"].shape[0], qhf["vectors"].shape[1]), dtype="float32")
        d_vectors_full = np.empty((dhf["vectors"].shape[0], dhf["vectors"].shape[1]), dtype="float32")
        qhf["vectors"].read_direct(q_vectors_full)
        dhf["vectors"].read_direct(d_vectors_full)
        # Keep only the queries that have judgments in the (possibly sampled) qrel.
        valid_q_mask = np.array([qid in qrel for qid in qids], dtype=bool)
        valid_q_indices = np.flatnonzero(valid_q_mask)
        qids = qids[valid_q_mask]
        q_vectors_full = q_vectors_full[valid_q_indices]

    results = []
    for exp in exps:
        dense_dim = exp["dense_dim"]
        sparse_t = exp["sparse_threshold"]
        exp_result = process_exp(q_vectors_full, d_vectors_full, qids, dids, qrel, dense_dim, sparse_t, is_hybrid)
        print(exp)
        print(exp_result)
        results.append({
            "model": model_name,
            "dataset": dataset_name,
            "hybrid": 1 if is_hybrid else 0,
            "dense_dim": dense_dim,
            "sparse_t": sparse_t,
            "ndcg": exp_result["ndcg"],
            # recall was computed but never copied into the row; the DictWriter below
            # ignores it if the CSV header has no such column.
            "recall": exp_result["recall"],
            "dense_time": exp_result["dense_cost_time"],
            "dense_space": exp_result["dense_cost_space"],
            "sparse_time": exp_result["sparse_cost_time"],
            "sparse_space": exp_result["sparse_cost_space"],
            "hybrid_time": exp_result["hybrid_cost_time"],
            "hybrid_space": exp_result["hybrid_cost_space"],
            "result": exp_result
        })

    # Append the rows to the result CSV, reusing its existing header so the column order
    # stays stable. If the file is missing or empty, create it with a header built from
    # the row keys.
    data_path = os.path.join(PROJECT_ROOTPATH, "result", "hybrid_retrieval_subset.csv")
    header = None
    if os.path.isfile(data_path):
        with open(data_path, newline='', encoding='utf-8') as f:
            header = next(csv.reader(f), None)
    write_header = not header
    if write_header:
        header = list(results[0].keys()) if results else []

    if header:
        os.makedirs(os.path.dirname(data_path), exist_ok=True)
        with open(data_path, 'a', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=header, extrasaction='ignore')
            if write_header:
                writer.writeheader()
            writer.writerows(results)

    print("done")

