import argparse
import os
import json
import numpy as np
from FlagEmbedding import FlagICLModel, FlagModel
import jsonlines
from tqdm import tqdm
from utils import load_jsonl, save_to_hdf5
import h5py

# Root of the project checkout; the script entry point looks for
# configs/text_encode_config.json underneath it.
PROJECT_ROOTPATH = "/rootpath/adaptiveLengthEmbedding"

# Placeholder roots. When the file is run as a script they are rebound from
# the JSON config before main() is called. They must be plain strings because
# main() passes them to os.path.join().
MODEL_ROOTPATH = ""
DATA_ROOTPATH = ""
OUTPUT_ROOTPATH = ""
def main(model_name: str, dataset_name: str, *,
         encode_batch_size: int = 64, chunk_size: int = 128_000) -> None:
    """Encode a dataset's queries and corpus and store the vectors as HDF5.

    Reads ``corpus.jsonl`` and ``queries.jsonl`` from
    ``DATA_ROOTPATH/<dataset_name>`` and writes ``q_vectors.h5`` and
    ``d_vectors.h5`` to
    ``OUTPUT_ROOTPATH/<model_name>_<dataset_name>_exp/origin``.

    Documents are sorted by the whitespace-token count of their ``text``
    field, longest first, and are encoded and stored in that order.

    Args:
        model_name: Sub-directory of ``MODEL_ROOTPATH`` holding the model.
            The value ``"bge_en_icl"`` selects ``FlagICLModel``; anything
            else is loaded with ``FlagModel``.
        dataset_name: Sub-directory of ``DATA_ROOTPATH`` holding the data.
        encode_batch_size: ``batch_size`` forwarded to the model's
            ``encode_queries`` / ``encode_corpus`` calls.
        chunk_size: Number of documents encoded per ``encode_corpus`` call
            and written to the HDF5 file in one slice assignment.

    Raises:
        ValueError: If the corpus contains no documents.
    """
    exp_dir = f"{model_name}_{dataset_name}_exp"
    model_path = os.path.join(MODEL_ROOTPATH, model_name)
    data_path = os.path.join(DATA_ROOTPATH, dataset_name)
    output_path = os.path.join(OUTPUT_ROOTPATH, exp_dir, "origin")

    # Load the data.
    corpus = load_jsonl(os.path.join(data_path, "corpus.jsonl"))
    queries = load_jsonl(os.path.join(data_path, "queries.jsonl"))

    # Create the output directory before the model is loaded, so an
    # unwritable location fails before any encoding work is done.
    os.makedirs(output_path, exist_ok=True)

    # Sort documents by length, descending, so the longest come first.
    corpus = sorted(
        corpus,
        key=lambda doc: len(doc['text'].split()),
        reverse=True,
    )

    model = _load_model(model_name, model_path)

    # Encode and save the queries.
    query_texts = [q["text"] for q in queries]
    query_ids = [q["_id"] for q in queries]
    query_embeddings = model.encode_queries(
        query_texts, batch_size=encode_batch_size, convert_to_numpy=True
    )
    save_to_hdf5(os.path.join(output_path, "q_vectors.h5"), query_ids, query_embeddings)

    # Build the document inputs, keeping the sorted order. The title is
    # prepended to the text when the record has one.
    doc_texts = [
        f"{doc['title']} {doc['text']}" if 'title' in doc else doc['text']
        for doc in corpus
    ]
    doc_ids = [doc["_id"] for doc in corpus]
    if not doc_texts:
        raise ValueError(
            f"corpus is empty: {os.path.join(data_path, 'corpus.jsonl')}"
        )

    _write_doc_vectors(
        model,
        os.path.join(output_path, "d_vectors.h5"),
        doc_ids,
        doc_texts,
        encode_batch_size,
        chunk_size,
    )


def _load_model(model_name, model_path):
    """Build FlagICLModel for "bge_en_icl", FlagModel for any other name."""
    if model_name == "bge_en_icl":
        return FlagICLModel(
            model_path,
            query_instruction_for_retrieval="retrieve relevant passages that answer the query.",
            examples_for_task=None,
            use_fp16=True,
        )
    return FlagModel(
        model_path,
        query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
        use_fp16=True,
    )


def _write_doc_vectors(model, output_file, doc_ids, doc_texts,
                       encode_batch_size, chunk_size):
    """Encode doc_texts chunk by chunk and write ids and vectors to output_file."""
    num_docs = len(doc_texts)

    # Encode a single document only to learn the embedding dimension, which
    # is needed to allocate the full 'vectors' dataset up front.
    sample_embedding = model.encode_corpus(
        [doc_texts[0]], batch_size=1, convert_to_numpy=True
    )
    embedding_dim = sample_embedding.shape[1]

    with h5py.File(output_file, 'w') as hf:
        vectors_ds = hf.create_dataset(
            'vectors',
            shape=(num_docs, embedding_dim),
            dtype=np.float32,
        )
        # Store the document ids; row i of 'vectors' belongs to ids[i].
        hf.create_dataset(
            'ids',
            data=[doc_id.encode() for doc_id in doc_ids],
            dtype=h5py.string_dtype(),
        )

        # Encode and write one chunk at a time so only one chunk of
        # embeddings is held in memory before being written.
        for start_idx in tqdm(
                range(0, num_docs, chunk_size),
                desc="Encoding documents in batches"
        ):
            # Clamp explicitly so the final, shorter chunk has exact bounds.
            end_idx = min(start_idx + chunk_size, num_docs)
            chunk_embeddings = model.encode_corpus(
                doc_texts[start_idx:end_idx],
                batch_size=encode_batch_size,
                convert_to_numpy=True,
            )
            vectors_ds[start_idx:end_idx] = chunk_embeddings


if __name__ == "__main__":
    # Command-line interface: both arguments are required.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--dataset_name", type=str, required=True)
    args = parser.parse_args()

    # Read the root directories from the project's JSON config. The encoding
    # is given explicitly so parsing does not depend on the platform default.
    config_path = os.path.join(PROJECT_ROOTPATH, "configs", "text_encode_config.json")
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Rebind the module-level roots that main() reads when it builds its
    # model, data and output paths.
    MODEL_ROOTPATH = config["model_rootpath"]
    DATA_ROOTPATH = config["data_rootpath"]
    OUTPUT_ROOTPATH = config["output_rootpath"]

    main(args.model_name, args.dataset_name)
