import sys, json
import kgen.models as models
import kgen.executor.tipo as tipo
from kgen.formatter import seperate_tags, apply_format
from kgen.generate import generate
from kgen.executor.tipo import (
    parse_tipo_request,
    tipo_single_request,
    tipo_runner,
    apply_tipo_prompt,
    parse_tipo_result,
    OPERATION_LIST,
)
from kgen.formatter import seperate_tags, apply_format

# Script: read prompts from a JSONL file (one object per line with a "prompt"
# field), append the "safe" tag, expand each prompt with the TIPO model via
# kgen, and append {"prompt": <enhanced>} lines to OUTPUT_PATH.
#
# Usage: python enhance_prompts.py <input_prompts.jsonl>

# Results are appended (mode "a"), so repeated runs accumulate into this file.
OUTPUT_PATH = "tipo_full.jsonl"

# Template passed to apply_format; the <|category|> placeholders are presumably
# filled from the tag map returned by tipo_runner.
# NOTE(review): the 8-space indentation on the continuation lines (and the
# trailing spaces after the first two commas) are reproduced exactly from the
# original in-loop literal. The indentation looks accidental; confirm whether
# apply_format strips it before removing it.
TIPO_FORMAT = (
    "<|special|>, \n"
    "        <|characters|>, <|copyrights|>, \n"
    "\n"
    "        <|general|>,\n"
    "\n"
    "        <|extended|>.\n"
    "\n"
    "        <|quality|>, <|meta|>, <|rating|>"
)

if len(sys.argv) < 2:
    print("Usage: python enhance_prompts.py <input_prompts.jsonl>")
    sys.exit(1)
input_path = sys.argv[1]

models.load_model(
    "TIPO-Anonymous/TIPO-500M-ft",
    device="cuda",
)

# Open the output once for the whole run instead of once per prompt; each
# record is flushed immediately so progress survives an interrupted run.
with open(input_path, "r", encoding="utf-8") as infile, open(
    OUTPUT_PATH, "a", encoding="utf-8"
) as outfile:
    for idx, line in enumerate(infile, start=1):
        if not line.strip():
            continue  # tolerate blank lines (e.g. a trailing newline)
        data = json.loads(line)

        # `or ""` also covers an explicit JSON null for "prompt".
        raw_prompt = (data.get("prompt") or "").strip()
        if not raw_prompt:
            continue  # skip empty prompt lines
        # Append the "safe" tag only after the emptiness check; appending it
        # first made the check above unreachable.
        prompt_text = raw_prompt + ", safe"          # e.g. "bond, safe"

        user_tags = [t.strip() for t in prompt_text.split(",") if t.strip()]
        # Presumably groups the tags by category, e.g. {"general": ["bond"],
        # "rating": ["safe"]} -- TODO confirm against kgen.formatter.
        org_tag_map = seperate_tags(user_tags)
        meta, operations, general, nl_prompt = parse_tipo_request(
            org_tag_map,
            "",
            tag_length_target="long",
            nl_length_target="long",
            generate_extra_nl_prompt=True,
        )

        # Override the aspect-ratio meta value (presumably square) -- confirm.
        meta["aspect_ratio"] = "1.0"

        tag_map, _ = tipo_runner(
            meta,
            operations,
            general,
            nl_prompt,
            temperature=0.6,
            # Seed derives from the input line number, so each line gets a
            # distinct but reproducible seed (skipped lines keep their numbers).
            seed=42 + idx,
            top_p=0.95,
            min_p=0.05,
            top_k=80,
        )

        enhanced_prompt = apply_format(tag_map, TIPO_FORMAT)

        json.dump({"prompt": enhanced_prompt}, outfile)
        outfile.write("\n")
        outfile.flush()
        print(f"Saved enhanced prompt to {OUTPUT_PATH}")
