import torch
from sd3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from sd3.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
import random
import os
import glob
import json

# Build the SD3.5-Large pipeline in bfloat16, move it to the GPU, and replace its
# scheduler with FlowMatchEulerDiscreteScheduler (from the local `sd3` package),
# configured from the checkpoint's "scheduler" subfolder.
MODEL_ID = "stabilityai/stable-diffusion-3.5-large"

pipe = StableDiffusion3Pipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe.scheduler = scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_ID,
    subfolder="scheduler",
    torch_dtype=torch.bfloat16,
)


# Generation settings. Paths are placeholders to be filled in per environment.
NUM_PROMPTS = 5000
SAMPLES_PER_PROMPT = 10
MAX_SEED = 10000  # seeds are drawn from the inclusive range [0, MAX_SEED]
PROMPT_DIR = "/path/to/prompt"
NOISE_DIR = "/path/to/output/noise"
PRED_DIR = "/path/to/output/pred"

# Create the output directories once, not on every sample.
os.makedirs(NOISE_DIR, exist_ok=True)
os.makedirs(PRED_DIR, exist_ok=True)

# For each prompt JSON (reads its "caption" field), run the pipeline
# SAMPLES_PER_PROMPT times with distinct random seeds. Each run saves the
# returned noise latents (.pt) and the first predicted image (.jpg).
with torch.no_grad():
    for i in range(NUM_PROMPTS):
        # Prompt files use 5-digit zero padding, e.g. 00042.json.
        json_path = os.path.join(PROMPT_DIR, f"{i:05d}.json")
        with open(json_path, "r", encoding="utf-8") as f:
            info = json.load(f)
        prompt = info["caption"]

        # Output files use a 4-digit index prefix. This is narrower than the
        # input names but kept as-is so existing consumers of the outputs still
        # find them.
        out_prefix = f"{i:04d}"

        # Draw distinct seeds: independent randint draws could repeat a seed and
        # produce byte-identical duplicate samples for the same prompt.
        seeds = random.sample(range(MAX_SEED + 1), SAMPLES_PER_PROMPT)
        for j, seed in enumerate(seeds):
            # The local sd3 pipeline returns a 4-tuple. Only the noise latents
            # (3rd) and the prediction output (4th) are kept here. `mid_k` and
            # `if_break` are project-specific kwargs of that pipeline; see
            # sd3/pipeline_stable_diffusion_3.py for their meaning.
            _, _, noise_latents_, predict_ = pipe(
                prompt=prompt,
                num_inference_steps=28,
                height=1024,
                width=1024,
                guidance_scale=3.5,
                mid_k=7,
                if_break=True,
                generator=torch.Generator("cpu").manual_seed(seed),
            )

            predict = predict_.images[0]
            # The seed is encoded in the filename so each sample can be regenerated.
            torch.save(
                noise_latents_.cpu(),
                os.path.join(NOISE_DIR, f"{out_prefix}_{j}_{seed}.pt"),
            )
            predict.save(os.path.join(PRED_DIR, f"{out_prefix}_{j}_{seed}.jpg"))
            print("{}/{}:{}".format(i, NUM_PROMPTS, j))