import os
import sys
import numpy as np
from multiprocessing import Process,set_start_method  
import torch
from diffusers import FluxKontextPipeline,FluxTransformer2DModel
from PIL import Image

def image_grid(imgs, rows, cols):
    """Tile ``imgs`` row-major into one RGB image with ``rows`` x ``cols`` cells.

    Every cell is sized like ``imgs[0]``; each image is pasted with its
    top-left corner at its cell's origin.

    Raises AssertionError if ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols
    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for pos, tile in enumerate(imgs):
        # Row-major placement: pos -> (row, col).
        row, col = divmod(pos, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas

# Configuration: base model, LoRA checkpoint, test inputs and output root.
BASE_MODEL = "FLUX.1-dev"
LORA_PATH = "mmface/checkpoint-5000/"
TEST_ROOT = "MM-Celeba-HQ/test_data/"
SAVE_ROOT = "MM-Celeba-HQ/mmface/"

# Restrict this process to physical GPU 0. This is set after `import torch`
# but before any CUDA call below; presumably effective because torch
# initialises CUDA lazily.
os.environ["CUDA_VISIBLE_DEVICES"] = str(0)
device = torch.device("cuda:0")

# One seeded generator, shared by every pipeline call made with it.
seed = 42
generator = torch.Generator(device).manual_seed(seed)

# Load the FLUX transformer explicitly in bfloat16 and hand it to the pipeline.
transformer = FluxTransformer2DModel.from_pretrained(
    BASE_MODEL, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = FluxKontextPipeline.from_pretrained(
    BASE_MODEL,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
)
# Apply the fine-tuned LoRA checkpoint on top of the base weights, then move
# the whole pipeline to the selected GPU.
pipe.load_lora_weights(LORA_PATH)
pipe.to(device)

# mask2face: for every test mask, generate NUM_SAMPLES faces and save a
# 1 x (NUM_SAMPLES + 1) strip [mask, sample_1, ..., sample_N] as a JPEG.
START_IDX, END_IDX = 27500, 28800  # half-open range of test-image indices
NUM_SAMPLES = 3
IMAGE_SIZE = 512

mask_dir = os.path.join(TEST_ROOT, "mask")
out_dir = os.path.join(SAVE_ROOT, "mask2face")
# PIL's save() does not create missing directories; without this every save
# would fail (and, before errors were reported, fail silently).
os.makedirs(out_dir, exist_ok=True)

prompt = ""  # no text prompt; the mask is passed as the conditioning `image`

# NOTE(review): `generator` is shared across all calls, so each image's noise
# depends on how many pipeline calls ran before it (including partially
# failed indices). Reseeding per index, e.g. generator.manual_seed(seed + idx),
# would make each output reproducible on its own -- confirm before changing.
for idx in range(START_IDX, END_IDX):
    mask_path = os.path.join(mask_dir, f"{idx}.png")
    output_path = os.path.join(out_dir, f"{idx}.jpg")

    try:
        with torch.cuda.device(device):
            # Open in a `with` so the source file handle is closed promptly;
            # convert() loads the pixels, so mask_image stays valid afterwards.
            with Image.open(mask_path) as src:
                mask_image = src.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))

            panels = [mask_image]
            for _ in range(NUM_SAMPLES):
                gen_img = pipe(
                    image=mask_image,
                    prompt=prompt,
                    height=IMAGE_SIZE,
                    width=IMAGE_SIZE,
                    num_inference_steps=28,
                    guidance_scale=1,
                    max_sequence_length=512,
                    generator=generator,
                ).images[0]
                panels.append(gen_img)

            grid = image_grid(panels, 1, len(panels))
            grid.save(output_path)
            print(output_path)
    except Exception as e:
        # Best effort: report the failure, free cached GPU memory, move on.
        print(f"[mask2face] idx={idx} failed: {e!r}", file=sys.stderr)
        torch.cuda.empty_cache()

