import argparse
import os
import json
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from utils import load_jsonl, save_to_hdf5

# Root of the project checkout; the JSON config is read from
# <PROJECT_ROOTPATH>/configs/text_encode_config.json when run as a script.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"
# Placeholders that the script entry point overwrites with values from the
# JSON config. They must be plain strings so os.path.join() accepts them.
MODEL_ROOTPATH = ""
DATA_ROOTPATH = ""
# A trailing comma here would make this a 1-tuple ("",) instead of a string.
OUTPUT_ROOTPATH = ""
def main(model_name, dataset_name):
    """Encode a dataset's queries and corpus and store the vectors as HDF5.

    Reads ``corpus.jsonl`` and ``queries.jsonl`` from
    ``DATA_ROOTPATH/<dataset_name>``, loads the model located at
    ``MODEL_ROOTPATH/<model_name>`` in float16, encodes both collections
    through a multi-process pool with ``normalize_embeddings=True`` and
    writes the results to ``q_vectors.h5`` / ``d_vectors.h5`` under
    ``OUTPUT_ROOTPATH/<model_name>_<dataset_name>_exp/origin``.

    Args:
        model_name: Name of the model directory below ``MODEL_ROOTPATH``.
        dataset_name: Name of the dataset directory below ``DATA_ROOTPATH``.
    """
    exp_dir = f"{model_name}_{dataset_name}_exp"
    model_path = os.path.join(MODEL_ROOTPATH, model_name)
    data_path = os.path.join(DATA_ROOTPATH, dataset_name)
    output_path = os.path.join(OUTPUT_ROOTPATH, exp_dir, "origin")

    # Load the data.
    corpus = load_jsonl(os.path.join(data_path, "corpus.jsonl"))
    queries = load_jsonl(os.path.join(data_path, "queries.jsonl"))

    # Sort documents by whitespace-token count, longest first.
    corpus = sorted(
        corpus,
        key=lambda doc: len(doc["text"].split()),
        reverse=True,
    )

    model = SentenceTransformer(
        model_path,
        trust_remote_code=True,
        model_kwargs={"torch_dtype": torch.float16},
    )
    model.max_seq_length = 2048  # Maximum sequence length fed to the model.

    # Build the inputs; documents keep the sorted order established above.
    query_texts = [q["text"] for q in queries]
    query_ids = [q["_id"] for q in queries]
    doc_texts = [
        f"{doc['title']} {doc['text']}" if "title" in doc else doc["text"]
        for doc in corpus
    ]
    doc_ids = [doc["_id"] for doc in corpus]

    # Options shared by the query and the document encoding calls.
    encode_kwargs = {
        "batch_size": 32,
        "show_progress_bar": True,
        "normalize_embeddings": True,
    }
    query_kwargs = dict(encode_kwargs)
    if model_name == "gte-Qwen2-7B-instruct":
        # Only this model is encoded with its "query" prompt.
        query_kwargs["prompt_name"] = "query"

    pool = model.start_multi_process_pool()
    try:
        query_embeddings = model.encode_multi_process(
            query_texts, pool=pool, **query_kwargs
        )
        doc_embeddings = model.encode_multi_process(
            doc_texts, pool=pool, **encode_kwargs
        )
    finally:
        # Always shut the worker processes down, even if encoding fails;
        # otherwise they are left running after this function returns.
        model.stop_multi_process_pool(pool)

    os.makedirs(output_path, exist_ok=True)
    save_to_hdf5(os.path.join(output_path, "q_vectors.h5"), query_ids, query_embeddings)
    save_to_hdf5(os.path.join(output_path, "d_vectors.h5"), doc_ids, doc_embeddings)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Register the command-line arguments.
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--dataset_name", type=str, required=True)

    # Parse the command-line arguments.
    args = parser.parse_args()

    # Read the root paths from the project's JSON config. The encoding is
    # given explicitly so the result does not depend on the platform default.
    config_path = os.path.join(PROJECT_ROOTPATH, "configs", "text_encode_config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Overwrite the module-level placeholders that main() reads.
    MODEL_ROOTPATH = config["model_rootpath"]
    DATA_ROOTPATH = config["data_rootpath"]
    OUTPUT_ROOTPATH = config["output_rootpath"]

    main(args.model_name, args.dataset_name)
