import os
from PIL import Image
from diffusers import StableDiffusionPipeline
import torch


def sd21_text2img(
    prompt,
    output_dir,
    load_lora1,
    *,
    lora1_path='resources/backup/lora_ckpt_exp6/aki-000400.safetensors',
    seed=1000,
):
    """Generate one image from ``prompt`` with Stable Diffusion 2.1 base and save it as PNG.

    The pipeline is loaded in float16 with ``local_files_only=True``, moved to
    the ``"cuda"`` device, run once with a seeded generator, and released
    again before returning (also when generation or saving fails).

    Args:
        prompt: Text prompt passed to the pipeline. It is also used to build
            the file name: spaces and path separators become underscores.
        output_dir: Directory the PNG is written to. It is created if it does
            not exist; an empty string means the current working directory.
        load_lora1: When truthy, load the LoRA weights at ``lora1_path`` into
            the pipeline under the adapter name ``"lora1"``.
        lora1_path: Path of the LoRA checkpoint used when ``load_lora1`` is
            truthy. Defaults to the checkpoint that used to be hard-coded.
        seed: Seed for the CUDA random generator (default 1000).

    Returns:
        The path of the saved image, ``<output_dir>/<sanitized prompt>.png``.
    """
    # Build the target path first so a bad prompt/output_dir fails before the
    # expensive model load. A path separator inside the prompt would otherwise
    # point the file into a sub-directory that most likely does not exist.
    file_stem = prompt.replace(' ', '_')
    for separator in (os.sep, os.altsep):
        if separator:  # os.altsep is None on POSIX
            file_stem = file_stem.replace(separator, '_')
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    img_path = os.path.join(output_dir, f"{file_stem}.png")

    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",
        torch_dtype=torch.float16,
        local_files_only=True,
    )
    try:
        if load_lora1:
            # NOTE(review): the "SF3D img2mesh" prefix looks copied from
            # another pipeline stage; kept unchanged in case logs are parsed.
            print(f"SF3D img2mesh load lora1: {lora1_path}")
            pipe.load_lora_weights(lora1_path, adapter_name="lora1")
        pipe.to("cuda")
        generator = torch.Generator("cuda").manual_seed(seed)
        images = pipe(prompt, generator=generator).images
        images[0].save(img_path, format='PNG')
    finally:
        # Drop the pipeline and return cached GPU memory even if generation
        # or saving raised, so a failed call does not keep VRAM occupied.
        del pipe
        torch.cuda.empty_cache()

    return img_path