from re import S
import torch
import os
from PIL import Image
import numpy as np
from diffusers import StableDiffusion3Pipeline
from flow_grpo.diffusers_patch.sd3_pipeline_with_logprob import pipeline_with_logprob
import importlib
from peft import PeftModel
import random
import shutil

# Sample a handful of images from a Stable Diffusion 3.5 pipeline, optionally
# with a LoRA adapter applied to the transformer, and write them to `save_dir`
# as "<index>_<seed>.png".
model_id = "SD3.5-medium"
device = "cuda"
save_dir = "check_img"

# Start from an empty output directory so images from an earlier run are not
# mixed with this run's results.
if os.path.exists(save_dir):
    shutil.rmtree(save_dir)
os.makedirs(save_dir)


# Path to a saved PEFT/LoRA adapter for the SD3 transformer.
# Leave empty to sample from the base model only.
lora_path = ""


pipe = StableDiffusion3Pipeline.from_pretrained(model_id, torch_dtype=torch.float16)
if lora_path:
    # Install the adapter-wrapped model as the pipeline's `transformer`
    # component so the pipeline's forward passes go through the adapter.
    pipe.transformer = PeftModel.from_pretrained(pipe.transformer, lora_path)
else:
    print("lora_path is empty; sampling from the base model without a LoRA adapter.")
pipe = pipe.to(device)

prompt_list = [
    'A little cute girl smiling, painting style.',
]
num_images = 10
for i in range(num_images):
    # A fresh random seed per image. It is recorded in the filename so any
    # image can be regenerated later. Replace it with a fixed value
    # (e.g. 48) for a reproducible run.
    seed = random.randint(1, 10000)
    prompt = prompt_list[i % len(prompt_list)]  # cycle through the prompts
    image = pipe(
        prompt,
        # Optional: uncomment to steer away from common artifacts.
        # negative_prompt= "blurry, low quality, low resolution, bad anatomy, extra limbs, disfigured",
        num_inference_steps=28,
        guidance_scale=4.5,
        generator=torch.Generator(device).manual_seed(seed),
    ).images[0]
    image.save(os.path.join(save_dir, f'{i}_{seed}.png'))