#!/usr/bin/env python
import sys, json, random, numpy as np
from pathlib import Path
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

def write_jsonl(path: Path, obj: dict):
    """Append ``obj`` to ``path`` as a single JSON line (UTF-8, non-ASCII kept).

    The object is serialised *before* the file is opened and written in one
    call. ``json.dump`` would stream partial chunks to the file, so a value
    that cannot be encoded would leave a truncated, newline-less fragment
    behind and corrupt the next appended record as well.

    Args:
        path: Target JSONL file; created if missing, appended to otherwise.
        obj: JSON-serialisable mapping to write.

    Raises:
        TypeError: If ``obj`` contains a value ``json`` cannot encode. The
            file is left untouched in that case.
    """
    line = json.dumps(obj, ensure_ascii=False) + "\n"
    with open(path, "a", encoding="utf-8") as f:
        f.write(line)

# ---------------------------------------------------------------------------
# Promptist (Microsoft)
# ---------------------------------------------------------------------------

def load_promptist(device="cpu"):
    """Load Microsoft's Promptist model and the GPT-2 tokenizer it is used with.

    Args:
        device: Torch device the model is moved to (e.g. ``"cpu"``, ``"cuda"``).

    Returns:
        ``(model, tokenizer)``: the causal LM in eval mode on ``device``, and
        a GPT-2 tokenizer that pads with its EOS token on the left.
    """
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 has no dedicated pad token, so reuse EOS. Left padding is what a
    # decoder-only model needs if prompts are ever batched for generation.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    model = model.to(device)
    model.eval()  # inference only: disables dropout
    return model, tokenizer

def enhance_promptist(model, tok, prompt, device):
    """Rewrite ``prompt`` with Promptist and return only the generated rewrite.

    Promptist is prompted with ``"<prompt> Rephrase:"`` and decoded with 8-beam
    search for at most 75 new tokens.

    The prompt is removed by slicing off its *tokens* from the generated
    sequence, not by string-replacing it in the decoded text. Decoding does
    not always reproduce the input string exactly (tokenizer clean-up can
    normalise spacing). When that happens a text ``replace`` silently misses
    and the whole ``"... Rephrase:"`` prefix leaks into the result. A
    ``replace`` would also delete repeats of that text inside the output.

    Args:
        model: Causal LM returned by ``load_promptist``.
        tok: Matching tokenizer returned by ``load_promptist``.
        prompt: Plain user prompt.
        device: Device the model lives on; inputs are moved there.

    Returns:
        The rephrased prompt, stripped of surrounding whitespace.
    """
    enc = tok(f"{prompt} Rephrase:", return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(
            enc.input_ids,
            # Explicit mask and pad id: identical results for a single
            # unpadded prompt, but avoids generate()'s ambiguity warnings.
            attention_mask=enc.attention_mask,
            max_new_tokens=75,
            num_beams=8,
            pad_token_id=tok.eos_token_id,
        )
    # For a decoder-only model the output begins with the prompt tokens.
    new_tokens = out[0, enc.input_ids.shape[-1]:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()

# ---------------------------------------------------------------------------
# Prompt‑Extend (DasPartho)
# ---------------------------------------------------------------------------

def load_extend(device="cpu"):
    """Build a text-generation pipeline for ``daspartho/prompt-extend``.

    Args:
        device: Torch device string or ``torch.device``, e.g. ``"cpu"``,
            ``"cuda"`` or ``"cuda:1"``.

    Returns:
        A ``transformers`` text-generation pipeline on the requested device.
    """
    dev = torch.device(device)
    # The pipeline takes a GPU ordinal (-1 = CPU). Honour an explicit index
    # such as "cuda:1" instead of always picking GPU 0; bare "cuda" still
    # maps to GPU 0 and every non-CUDA device still maps to CPU.
    pipe_device = (dev.index or 0) if dev.type == "cuda" else -1
    return pipeline("text-generation", model="daspartho/prompt-extend", device=pipe_device)

def enhance_extend(pipe, prompt):
    """Extend ``prompt`` with prompt-extend and return only the added text.

    ``return_full_text=False`` makes the pipeline drop the prompt itself,
    using the length of its own decoded prompt. Comparing the decoded output
    against the raw ``prompt`` string fails whenever decoding normalises the
    prompt's spacing, and the whole prompt would then leak into the result.

    Args:
        pipe: Text-generation pipeline returned by ``load_extend``.
        prompt: Plain user prompt.

    Returns:
        The generated continuation (at most 75 new tokens), with leading
        commas/spaces and surrounding whitespace removed.
    """
    continuation = pipe(prompt, max_new_tokens=75, return_full_text=False)[0]["generated_text"]
    # Drop the ", " separator that joins the continuation to the prompt.
    return continuation.lstrip(", ").strip()

# ---------------------------------------------------------------------------
# MagicPrompt‑Dalle (Gustavosta)
# ---------------------------------------------------------------------------

def load_magic(device="cpu"):
    """Build a text-generation pipeline for ``Gustavosta/MagicPrompt-Dalle``.

    Args:
        device: Torch device string or ``torch.device``, e.g. ``"cpu"``,
            ``"cuda"`` or ``"cuda:1"``.

    Returns:
        A ``transformers`` text-generation pipeline on the requested device.
    """
    dev = torch.device(device)
    # The pipeline takes a GPU ordinal (-1 = CPU). Honour an explicit index
    # such as "cuda:1" instead of always picking GPU 0; bare "cuda" still
    # maps to GPU 0 and every non-CUDA device still maps to CPU.
    pipe_device = (dev.index or 0) if dev.type == "cuda" else -1
    return pipeline("text-generation", model="Gustavosta/MagicPrompt-Dalle", device=pipe_device)

def enhance_magic(pipe, prompt):
    """Expand ``prompt`` with MagicPrompt-Dalle using temperature sampling.

    The pipeline is called with its default output mode, and the returned
    text is passed through unchanged apart from stripping. Unlike
    ``enhance_extend``, no prompt prefix is removed here.

    Args:
        pipe: Text-generation pipeline returned by ``load_magic``.
        prompt: Plain user prompt.

    Returns:
        The first generated text, stripped of surrounding whitespace.
    """
    sampling = {"max_new_tokens": 75, "do_sample": True, "temperature": 0.9}
    results = pipe(prompt, **sampling)
    generated = results[0]["generated_text"]
    return generated.strip()

# ---------------------------------------------------------------------------
# main
# ---------------------------------------------------------------------------

def _iter_prompts(path):
    """Yield ``(line_number, prompt)`` for every usable record in a JSONL file.

    Blank lines are skipped quietly. A line is also skipped, with a warning on
    stderr, if it is not valid JSON, is not a JSON object, or has no non-empty
    string ``"prompt"``. One bad line therefore neither aborts a long run nor
    disappears without a trace. Yielded prompts are whitespace-stripped.
    """
    with open(path, encoding="utf-8") as infile:
        for idx, line in enumerate(infile, start=1):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                print(f"[{idx}] skipped: invalid JSON ({err})", file=sys.stderr)
                continue
            prompt = record.get("prompt") if isinstance(record, dict) else None
            if not isinstance(prompt, str) or not prompt.strip():
                print(f'[{idx}] skipped: no non-empty "prompt" string', file=sys.stderr)
                continue
            yield idx, prompt.strip()


def main():
    """Enhance every prompt of a JSONL file with three prompt-expansion models.

    Usage: ``python enhance_prompts.py <prompts.jsonl>``. Each input line is
    expected to be a JSON object with a ``"prompt"`` string. For every usable
    prompt, one ``{"prompt": ..., "enhanced": ...}`` line is appended to each
    of ``promptist.jsonl``, ``promptextend.jsonl`` and ``promptdb.jsonl`` in
    the current working directory. Those files are deleted at start-up.
    Exits with status 1 on a missing argument or a missing input file.
    """
    if len(sys.argv) < 2:
        print("Usage: python enhance_prompts.py <prompts.jsonl>", file=sys.stderr)
        sys.exit(1)

    input_path = Path(sys.argv[1])
    # Check the input before spending a long time loading three models.
    if not input_path.is_file():
        print(f"Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Seed every RNG in play; MagicPrompt generates with do_sample=True, so
    # this keeps its output repeatable between runs on the same setup.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    p_model, p_tok = load_promptist(device)
    e_pipe = load_extend(device)
    m_pipe = load_magic(device)

    outputs = {
        "promptist": Path("promptist.jsonl"),
        "promptextend": Path("promptextend.jsonl"),
        "promptdb": Path("promptdb.jsonl"),  # kept name for backward-compatibility
    }
    # write_jsonl appends, so clear results left over from a previous run.
    for p in outputs.values():
        p.unlink(missing_ok=True)

    for idx, prompt in _iter_prompts(input_path):
        # Run all three models before writing anything: if one of them raises,
        # the three output files still hold the same prompts line-for-line.
        enhanced = {
            "promptist": enhance_promptist(p_model, p_tok, prompt, device),
            "promptextend": enhance_extend(e_pipe, prompt),
            "promptdb": enhance_magic(m_pipe, prompt),
        }
        for name, text in enhanced.items():
            write_jsonl(outputs[name], {"prompt": prompt, "enhanced": text})
        print(f"[{idx}] {prompt}")

# Run the command-line entry point only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
