import json
from argparse import ArgumentParser
from pathlib import Path

import numpy as np
from datasets import Dataset, load_from_disk
from vllm import LLM

# One-sentence task instruction that get_detailed_instruct prepends to a text.
TASK_DESCRIPTION = (
    "Find similar texts, according to both content, "
    "writing style, etc."
)


def get_detailed_instruct(query: str) -> str:
    """Return *query* prefixed with the ``TASK_DESCRIPTION`` instruction.

    The result has the form ``"Instruct: <task>\\nText:<query>"`` (note: no
    space after ``Text:``).

    NOTE(review): this helper is not called anywhere in the visible code, so
    the texts embedded by this script are currently passed without it.
    """
    instruction_line = "Instruct: {}".format(TASK_DESCRIPTION)
    return "{}\nText:{}".format(instruction_line, query)


if __name__ == "__main__":
    parser = ArgumentParser()

    parser.add_argument(
        "--model_name",
        type=str,
        default="models_dir/qwen3-embed-4",
        help="HuggingFace model name or path.",
    )
    parser.add_argument(
        "--ref_data_path",
        type=str,
        default="data_dir/wikibio",
        help="Path to input JSONL file with 'content' field.",
    )
    parser.add_argument(
        "--ref_embed_path",
        type=str,
        default="results_dir/wikibio/ref_embeddings.npy",
        help="Path to output directory for reference embeddings.",
    )

    parser.add_argument(
        "--gen_data_path",
        type=str,
        default="data_dir/wikibio_gen",
        help="Path to input JSONL file with 'content' field.",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        default="results_dir/wikibio/generation.jsonl",
        help="Path to output directory for generated samples.",
    )
    parser.add_argument(
        "--embed_ref", action="store_true", help="Whether to embed reference data."
    )

    args = parser.parse_args()
    model = LLM(model=args.model_name, task="embed")

    if args.embed_ref:
        print("Embedding reference data...")
        # load reference and generation datasets
        ref_dataset = load_from_disk(args.ref_data_path)
        ref_data = ref_dataset["content"]

        # Each query must come with a one-sentence instruction that describes the task

        ref_embeddings = model.embed(ref_data)  # for testing, only embed 500 samples
        ref_embeddings = [item.outputs.embedding for item in ref_embeddings]
        ref_dataset = ref_dataset.add_column("embedding", ref_embeddings)
        Path(args.ref_embed_path).parent.mkdir(parents=True, exist_ok=True)
        ref_dataset.save_to_disk(args.ref_embed_path)
        print(
            f"Saved {len(ref_data)} reference texts and embeddings to {args.ref_embed_path}"
        )

    # gen dataset is a jsonl file, each line is a text
    with open(args.gen_data_path, "r") as f:
        gen_data = [json.loads(line) for line in f]

    gen_texts = [item["text"] for item in gen_data]
    gen_gender = [item["gender"] for item in gen_data]

    print(f"Loaded {len(gen_texts)} samples from {args.gen_data_path}")

    # Generate embeddings for the new data
    gen_embeddings = model.embed(gen_texts)
    gen_embeddings = [item.outputs.embedding for item in gen_embeddings]
    dataset = {"text": gen_texts, "embedding": gen_embeddings, "gender": gen_gender}
    dataset = Dataset.from_dict(dataset)
    Path(args.output_path).parent.mkdir(parents=True, exist_ok=True)

    dataset.save_to_disk(args.output_path)
    print(f"Saved texts and generated embeddings to {args.output_path}")
