import sys
import json
import random
import numpy as np
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------

def write_jsonl(path: Path, obj: dict):
    """Append *obj* to the JSON Lines file at *path* as a single record.

    The object is serialised completely before the file is opened, so a value
    that cannot be encoded raises without leaving a truncated line behind
    (streaming with ``json.dump`` would already have written part of it).

    Args:
        path: Destination file; opened in append mode with UTF-8 encoding.
        obj: JSON-serialisable object written on one line. Non-ASCII
            characters are stored verbatim rather than ``\\u``-escaped.

    Raises:
        TypeError: If *obj* contains a value ``json`` cannot encode.
        ValueError: If *obj* contains a circular reference.
        OSError: If the file cannot be opened or written.
    """
    # Build the whole line first: either a complete record is appended or
    # nothing is, which keeps the file parseable line by line.
    line = json.dumps(obj, ensure_ascii=False) + "\n"
    with open(path, "a", encoding="utf-8") as f:
        f.write(line)

# ---------------------------------------------------------------------------
# BeautifulPrompt (Alibaba-PAI)
# ---------------------------------------------------------------------------

# Checkpoint identifier handed to `from_pretrained` for both the tokenizer
# and the causal language model in `load_beautifulprompt`.
MODEL_NAME = "alibaba-pai/pai-bloom-1b1-text2prompt-sd"


def load_beautifulprompt(device: str = "cpu"):
    """Instantiate the BeautifulPrompt tokenizer and causal LM.

    The model is moved to *device* and switched to evaluation mode.

    Returns:
        A ``(model, tokenizer)`` tuple, model first.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    network = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    on_device = network.to(device)
    return on_device.eval(), tokenizer


def enhance_beautifulprompt(model, tok, prompt: str, device: str, num_return_sequences: int = 5):
    """Expand a short description into Stable-Diffusion style prompts.

    The description is wrapped in the instruction template from the model
    card, sampled `num_return_sequences` times, and only the tokens generated
    after the template are decoded.

    Returns a *list* of whitespace-stripped strings, one per sampled sequence.
    """
    # Instruction prefix the model was trained with (see the model card).
    instruction = "Instruction: Give a simple description of the image to generate a drawing prompt.\n"
    query = instruction + f"Input: {prompt}\n" + "Output:"

    prompt_ids = tok.encode(query, return_tensors="pt").to(device)

    sampling = dict(
        max_length=384,
        do_sample=True,
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.2,
        num_return_sequences=num_return_sequences,
    )
    with torch.no_grad():
        sequences = model.generate(prompt_ids, **sampling)

    # Each returned row starts with the prompt tokens; drop them so only the
    # continuation is decoded.
    prompt_len = prompt_ids.size(1)
    continuations = sequences[:, prompt_len:]
    texts = tok.batch_decode(continuations, skip_special_tokens=True)
    return [text.strip() for text in texts]

# ---------------------------------------------------------------------------
# main
# ---------------------------------------------------------------------------

def main():
    """Enhance every prompt of a JSONL file with BeautifulPrompt.

    Expects one command-line argument: the path to a JSONL file holding one
    JSON object per line with a ``"prompt"`` string. For each usable line a
    record ``{"prompt": ..., "enhanced": [...]}`` is appended to
    ``./prompts/beautifulprompt.jsonl``; an existing output file is removed
    first so each run starts from scratch.

    Lines that are not valid JSON, are not JSON objects, or whose prompt is
    missing, empty, or not a string are reported and skipped.

    Exits with status 1 when the argument is missing or the input file does
    not exist.
    """
    if len(sys.argv) < 2:
        print("Usage: python bloom.py <prompts.jsonl>")
        sys.exit(1)

    input_path = Path(sys.argv[1])
    if not input_path.exists():
        print(f"Input file {input_path} not found.")
        sys.exit(1)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Seed every RNG in play so sampling is repeatable between runs.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    model, tok = load_beautifulprompt(device)

    output_path = Path("./prompts/beautifulprompt.jsonl")
    # Without this the first append fails with FileNotFoundError when the
    # ./prompts directory has not been created yet.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    if output_path.exists():
        output_path.unlink()

    with open(input_path, encoding="utf-8") as infile:
        for idx, line in enumerate(infile, start=1):
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                print(f"[{idx}] skipped – not valid JSON")
                continue

            # A line can be valid JSON without being an object (e.g. a list
            # or a bare string); such values have no ``.get``.
            if not isinstance(record, dict):
                print(f"[{idx}] skipped – not a JSON object")
                continue

            prompt = record.get("prompt", "")
            # ``null`` or numeric prompts would crash on ``.strip()``.
            if not isinstance(prompt, str):
                print(f"[{idx}] skipped – prompt is not a string")
                continue

            prompt = prompt.strip()
            if not prompt:
                print(f"[{idx}] skipped – empty prompt")
                continue

            enhanced_prompts = enhance_beautifulprompt(model, tok, prompt, device)
            write_jsonl(output_path, {"prompt": prompt, "enhanced": enhanced_prompts})
            print(f"[{idx}] {prompt}")


# Run the batch enhancement only when executed as a script, not on import.
if __name__ == "__main__":
    main()
