import torch
from sd3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from sd3.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
import random
import os
import glob
import json

# Hugging Face checkpoint shared by the pipeline and its scheduler.
MODEL_ID = "stabilityai/stable-diffusion-3.5-large"
device = "cuda"

# Load the project's SD3 pipeline in bfloat16 and move it to the GPU.
pipe = StableDiffusion3Pipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
pipe.to(device)

# Replace the pipeline's scheduler with the project's flow-matching Euler
# scheduler, configured from the same checkpoint's "scheduler" subfolder.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    MODEL_ID,
    subfolder="scheduler",
    torch_dtype=torch.bfloat16,
)
pipe.scheduler = scheduler

# Input/output locations (placeholders -- set these before running).
IMAGE_DIR = "/path/to/output/cofusion"
NOISE_DIR = "/path/to/output/noise"
PROMPT_DIR = "/path/to/prompt"
STYLE_JSON_PATH = "/path/to/reflection.json"
OUTPUT_DIR = "/path/to/output"
NUM_SAMPLES = 5000


def _find_image_stem(idx):
    """Return the file stem of the first ``{idx:05d}*.jpg`` image in IMAGE_DIR.

    Raises:
        FileNotFoundError: if no image matches the index.
    """
    pattern = os.path.join(IMAGE_DIR, f"{idx:05d}*.jpg")
    # glob.glob returns matches in arbitrary order; sort so the choice is
    # deterministic when several files share the same index prefix.
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"no image matching {pattern!r}")
    return os.path.splitext(os.path.basename(matches[0]))[0]


# The style file and the output directory do not change between iterations,
# so load / create them once instead of 5000 times.
with open(STYLE_JSON_PATH, 'r', encoding='utf-8') as f:
    # Indexed below as style_info[idx][str(idx)] -- presumably a list of
    # single-key dicts, one per sample; confirm against the file's producer.
    style_info = json.load(f)
os.makedirs(OUTPUT_DIR, exist_ok=True)

for idx in range(NUM_SAMPLES):
    file_name = _find_image_stem(idx)
    # Saved latents are looked up by the same stem as the matched image.
    latents_path = os.path.join(NOISE_DIR, f"{file_name}.pt")

    # The per-sample style text is passed to the pipeline as the negative prompt.
    style = style_info[idx][f'{idx}']
    print(style)

    json_path = os.path.join(PROMPT_DIR, f"{idx:05d}.json")
    with open(json_path, 'r', encoding='utf-8') as f:
        info = json.load(f)
    prompt = info['caption']
    # NOTE(review): `append_style` is not defined or imported anywhere in the
    # visible file, so this line raises NameError as written -- define or
    # import it before running.
    prompt = append_style(prompt)

    print(prompt)
    # The seed is the integer after the last underscore in the stem,
    # e.g. "00042_1234" -> 1234.
    last_part = file_name.rsplit('_', 1)[-1]
    print(last_part)
    seed = int(last_part)

    latents = torch.load(latents_path).to(device)

    # The custom pipeline returns a 4-tuple; only the first element
    # (which carries `.images`) is used here.
    output, _, _, _ = pipe(
        prompt=prompt,
        negative_prompt=style,
        num_inference_steps=28,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        mid_stage=7,
        generator=torch.Generator("cpu").manual_seed(seed),
        continue_to_gen=True,
        mid_latents=latents,
    )
    image = output.images[0]
    image.save(os.path.join(OUTPUT_DIR, f"{file_name}.jpg"))
