import argparse
import os
import json
import numpy as np
from FlagEmbedding import FlagICLModel, FlagModel
import jsonlines
from tqdm import tqdm
from utils import load_jsonl, save_to_hdf5

# Root of the project checkout; the script entry point reads
# configs/text_encode_config.json from underneath this directory.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"

# Placeholder roots. When the file is run as a script they are overwritten
# with the values found in the config file before main() is called.
# They must be plain strings (not tuples) because main() passes them
# straight to os.path.join.
MODEL_ROOTPATH = ""
DATA_ROOTPATH = ""
OUTPUT_ROOTPATH = ""
def main(model_name, dataset_name):
    """Encode a dataset's queries and corpus and store the vectors as HDF5.

    Reads ``corpus.jsonl`` and ``queries.jsonl`` from
    ``DATA_ROOTPATH/<dataset_name>``, loads the model found at
    ``MODEL_ROOTPATH/<model_name>``, encodes both collections and writes
    ``q_vectors.h5`` and ``d_vectors.h5`` into
    ``OUTPUT_ROOTPATH/<model_name>_<dataset_name>_exp/origin``.

    Args:
        model_name: Sub-directory of ``MODEL_ROOTPATH`` holding the model.
            The value ``"bge_en_icl"`` selects ``FlagICLModel``; anything
            else is loaded with ``FlagModel``.
        dataset_name: Sub-directory of ``DATA_ROOTPATH`` holding the data.
    """

    def _word_count(doc):
        # Whitespace-separated token count of the document body.
        return len(doc['text'].split())

    model_path = os.path.join(MODEL_ROOTPATH, model_name)
    data_path = os.path.join(DATA_ROOTPATH, dataset_name)
    output_path = os.path.join(
        OUTPUT_ROOTPATH, f"{model_name}_{dataset_name}_exp", "origin"
    )

    # Load the data.
    corpus = load_jsonl(os.path.join(data_path, "corpus.jsonl"))
    queries = load_jsonl(os.path.join(data_path, "queries.jsonl"))

    # Order documents by length, longest first.
    corpus = list(corpus)
    corpus.sort(key=_word_count, reverse=True)

    # Only the ICL checkpoint needs the dedicated wrapper class.
    if model_name != "bge_en_icl":
        model = FlagModel(
            model_path,
            query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
            use_fp16=True,
        )
    else:
        model = FlagICLModel(
            model_path,
            query_instruction_for_retrieval="retrieve relevant passages that answer the query.",
            examples_for_task=None,
            use_fp16=True,
        )

    # Collect texts and ids; documents keep the sorted order from above.
    query_texts = [query["text"] for query in queries]
    query_ids = [query["_id"] for query in queries]

    doc_texts = []
    for doc in corpus:
        if 'title' in doc:
            # Prepend the title when the record carries one.
            doc_texts.append(f"{doc['title']} {doc['text']}")
        else:
            doc_texts.append(doc['text'])
    doc_ids = [doc["_id"] for doc in corpus]

    # Encode queries first, then the corpus.
    query_embeddings = model.encode_queries(
        query_texts, batch_size=64, convert_to_numpy=True
    )
    doc_embeddings = model.encode_corpus(
        doc_texts, batch_size=64, convert_to_numpy=True
    )

    # Persist ids alongside their embeddings.
    os.makedirs(output_path, exist_ok=True)
    query_file = os.path.join(output_path, "q_vectors.h5")
    doc_file = os.path.join(output_path, "d_vectors.h5")
    save_to_hdf5(query_file, query_ids, query_embeddings)
    save_to_hdf5(doc_file, doc_ids, doc_embeddings)

if __name__ == "__main__":
    # Command-line interface: both the model and the dataset are mandatory.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--dataset_name", type=str, required=True)
    args = parser.parse_args()

    # Read the root paths from the project config. The encoding is given
    # explicitly so decoding does not depend on the machine's locale.
    config_path = os.path.join(PROJECT_ROOTPATH, "configs", "text_encode_config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Rebind the module-level roots that main() reads; a missing key raises
    # KeyError here, before any model or data is loaded.
    MODEL_ROOTPATH = config["model_rootpath"]
    DATA_ROOTPATH = config["data_rootpath"]
    OUTPUT_ROOTPATH = config["output_rootpath"]

    main(args.model_name, args.dataset_name)
