import os
import sys
import numpy as np
from multiprocessing import Process,set_start_method  
import torch
from diffusers import FluxKontextPipeline,FluxTransformer2DModel
from PIL import Image

def image_grid(imgs, rows, cols):
    """Tile images into a single ``rows`` x ``cols`` RGB contact sheet.

    The cell size is taken from the first image's ``size``; images are placed
    in order, filling each row left to right before moving down to the next.

    Args:
        imgs: Sequence of image objects (each must expose ``.size`` and be
            accepted by ``Image.paste``); its length must equal ``rows * cols``.
        rows: Number of rows in the sheet.
        cols: Number of columns in the sheet.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)`` with every input
        pasted into its cell.
    """
    assert len(imgs) == rows * cols
    cell_w, cell_h = imgs[0].size
    sheet = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for position, tile in enumerate(imgs):
        # divmod yields (row, column) for a row-major layout.
        row, col = divmod(position, cols)
        sheet.paste(tile, box=(col * cell_w, row * cell_h))
    return sheet


# --- Run configuration: checkpoints and dataset locations -------------------
BASE_MODEL = "FLUX.1-dev"
LORA_PATH = "mmface/checkpoint-5000/"
TEST_ROOT = "MM-Celeba-HQ/test_data/"
SAVE_ROOT = "MM-Celeba-HQ/mmface/"

# Limit the process to GPU 0 via the environment, then address it as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = torch.device("cuda:0")

# A single seeded generator is created here and reused by the code below, so
# successive pipeline calls draw from one continuous random stream.
seed = 42
generator = torch.Generator(device)
generator.manual_seed(seed)

# Both the transformer and the surrounding pipeline are loaded in bfloat16.
weights_dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_pretrained(
    BASE_MODEL,
    subfolder="transformer",
    torch_dtype=weights_dtype,
)

# NOTE(review): `_fast_inference` is not an obviously documented
# `from_pretrained` option -- confirm the pipeline actually recognises it.
pipeline_options = {
    "transformer": transformer,
    "torch_dtype": weights_dtype,
    "_fast_inference": False,
    "use_safetensors": True,
}
pipe = FluxKontextPipeline.from_pretrained(BASE_MODEL, **pipeline_options)

# Attach the fine-tuned LoRA weights on top of the base pipeline.
# NOTE(review): `is_main_process` looks like a training/saving-side flag --
# confirm it has any effect when loading.
pipe.load_lora_weights(LORA_PATH, is_main_process=False)

# Move the assembled pipeline onto the selected GPU.
pipe.to(device)

# text-2-face
# For every test caption, sample SAMPLES_PER_PROMPT faces from the pipeline
# and save them side by side as one 1 x SAMPLES_PER_PROMPT grid image.
SAMPLES_PER_PROMPT = 3

# Saving an image does not create missing parent directories, so make sure the
# output folder exists up front instead of failing on every single save.
os.makedirs(f"{SAVE_ROOT}/text2face", exist_ok=True)

# The conditioning image is the same all-black 512x512 canvas for every
# prompt, so it is built once rather than once per iteration.
# NOTE(review): presumably the black image stands for "no visual condition"
# in the LoRA's training setup -- confirm against the training code.
mask_image = Image.new('RGB', (512, 512), (0, 0, 0))

for idx in range(28500, 28800):
    txt_path = f"{TEST_ROOT}/text/{idx}.txt"
    output_path = f"{SAVE_ROOT}/text2face/{idx}.jpg"

    # Only the first line of the caption file is used as the prompt.
    with open(txt_path, "r") as f:
        prompt = f.readline().strip()

    try:
        with torch.cuda.device(device):
            # The shared `generator` is advanced by every call, so the
            # samples for one prompt differ from each other.
            samples = [
                pipe(
                    image=mask_image,
                    prompt=prompt,
                    height=512,
                    width=512,
                    num_inference_steps=28,
                    guidance_scale=1,
                    max_sequence_length=512,
                    generator=generator,
                ).images[0]
                for _ in range(SAMPLES_PER_PROMPT)
            ]
            grid = image_grid(samples, 1, SAMPLES_PER_PROMPT)

            grid.save(output_path)
            print(output_path)
    except Exception as e:
        # Best-effort batch run: one failing sample must not abort the rest,
        # but the failure is reported instead of being silently discarded.
        print(f"[text2face] idx={idx} failed: {e!r}", file=sys.stderr)
        # Release cached GPU memory before moving on to the next prompt.
        torch.cuda.empty_cache()

        