import os
import torch
from diffusers import FluxKontextPipeline,FluxTransformer2DModel
from diffusers.utils import load_image

# Base checkpoint directory and precision shared by every component loaded below.
# NOTE(review): FluxKontextPipeline is usually paired with a FLUX.1-Kontext
# checkpoint; confirm that loading it from a "FLUX.1-dev" directory is intended.
_BASE_MODEL = "FLUX.1-dev"
_DTYPE = torch.bfloat16

# Load the denoising transformer explicitly, then hand it to the pipeline so the
# pipeline reuses this instance instead of loading its own.
transformer = FluxTransformer2DModel.from_pretrained(
    _BASE_MODEL,
    subfolder="transformer",
    torch_dtype=_DTYPE,
)
pipeline = FluxKontextPipeline.from_pretrained(
    _BASE_MODEL, transformer=transformer, torch_dtype=_DTYPE
)

# Attach the fine-tuned LoRA adapter, then move the whole pipeline to the GPU.
pipeline.load_lora_weights("./mmface/checkpoint-5000/")
pipeline.to(torch.device("cuda"))


# Input directory (holds mask/ and text/ subfolders) and output directory.
TEST_ROOT = "MM-Celeba-HQ/test_data/"
SAVE_ROOT = "MM-Celeba-HQ/mmface/"

# One CUDA RNG, seeded once; manual_seed seeds in place and returns the same
# generator object, so splitting the call keeps the original semantics.
generator = torch.Generator(torch.device("cuda"))
generator.manual_seed(42)

# Generate one image per sample index. Each sample pairs a mask image
# (TEST_ROOT/mask/<idx>.png) with a prompt (TEST_ROOT/text/<idx>.txt), and
# the generated image is written to SAVE_ROOT/<idx>.jpg.
# NOTE(review): 27000-29999 presumably selects the MM-CelebA-HQ test split;
# confirm against the dataset's split definition.
START_IDX, END_IDX = 27000, 30000

# Image.save does not create parent directories, so create the output
# directory up front rather than crashing on the first save.
os.makedirs(SAVE_ROOT, exist_ok=True)

for idx in range(START_IDX, END_IDX):
    mask_path = os.path.join(TEST_ROOT, "mask", f"{idx}.png")
    text_path = os.path.join(TEST_ROOT, "text", f"{idx}.txt")

    # Only the first line of the caption file is used as the prompt; read it
    # as UTF-8 explicitly instead of relying on the platform locale.
    with open(text_path, "r", encoding="utf-8") as f:
        prompt = f.readline().strip()

    # NOTE(review): PIL's default resize filter for RGB images is smoothing
    # (bicubic in current Pillow), which blends label colours along mask
    # boundaries when the mask is not already 512x512. Consider
    # Image.NEAREST if masks can arrive at other sizes.
    mask_image = load_image(mask_path).convert("RGB")
    mask_image = mask_image.resize((512, 512))

    # All iterations share one seeded generator, so each sample's noise
    # depends on how many samples were generated before it. Restarting at a
    # different index will not reproduce the same outputs.
    result = pipeline(
        image=mask_image,
        prompt=prompt,
        height=512,
        width=512,
        num_inference_steps=28,
        guidance_scale=1,
        max_sequence_length=512,
        generator=generator,
    ).images[0]

    save_path = os.path.join(SAVE_ROOT, f"{idx}.jpg")
    result.save(save_path)