import os
import sys
import numpy as np
from multiprocessing import Process,set_start_method  
import torch
from diffusers import FluxKontextPipeline,FluxTransformer2DModel
from PIL import Image



# Path (or model id) of the base FLUX checkpoint; passed to both
# from_pretrained calls in process_files.
FLUX_MODEL_PATH = "FLUX.1-dev"
# LoRA checkpoint directory.
# NOTE(review): not referenced in the visible code -- the LoRA path actually
# used comes from sys.argv[1]; presumably a leftover default, confirm.
MMFACE_LORA_PATH = "mmface/checkpoint-5000/"
# Dataset root; process_files reads "mask/<name>.png" and "text/<name>.txt"
# below this directory.
TEST_ROOT = "MM-Celeba-HQ/groundtruth/"
# NOTE(review): not referenced in the visible code -- the output directory
# actually used comes from sys.argv[2]; confirm whether this can be removed.
SAVE_ROOT = "MM-FFHQ/EC2Face/face/"

def process_files(gpu_id, file_list,save_root,lora_path):
    """Process the list of mask files assigned to one GPU.

    For every name in ``file_list`` the mask image is read from
    ``TEST_ROOT/mask/<mask_name>``, the prompt is taken from the first line
    of ``TEST_ROOT/text/<mask_name with .png -> .txt>``, and the generated
    image is written to ``save_root/<mask_name with .png -> .jpg>``.
    Outputs that already exist are skipped, so an interrupted run can be
    resumed.

    Args:
        gpu_id: CUDA device ordinal (as seen by this process) to run on.
        file_list: Mask file names (not paths) to process.
        save_root: Directory the generated images are written to.
        lora_path: LoRA weights location handed to ``pipe.load_lora_weights``.

    Returns:
        None. Errors are reported with ``print`` and never propagate: a
        failure on one file does not stop the remaining files, and a
        failure during setup (device / model loading) aborts only this
        worker.
    """
    if len(file_list) == 0:
        # Nothing assigned to this worker: do not load the model for no work.
        return

    try:
        # Address the GPU by its ordinal directly. CUDA_VISIBLE_DEVICES must
        # NOT be rewritten here: once the mask is narrowed to a single GPU
        # that GPU becomes "cuda:0", so "cuda:{gpu_id}" would be an invalid
        # ordinal for every gpu_id >= 1. Overwriting the variable would also
        # discard a restriction inherited from the parent process.
        device = torch.device(f"cuda:{gpu_id}")
        torch.cuda.set_device(device)

        # Make sure the destination exists; exist_ok makes this safe when
        # several workers do it concurrently.
        if save_root:
            os.makedirs(save_root, exist_ok=True)

        transformer = FluxTransformer2DModel.from_pretrained(
                FLUX_MODEL_PATH, subfolder="transformer", torch_dtype=torch.bfloat16
        )

        # NOTE(review): `_fast_inference` is kept exactly as in the original
        # call; confirm that the installed diffusers version accepts it.
        pipe = FluxKontextPipeline.from_pretrained(
            FLUX_MODEL_PATH,
            transformer=transformer,
            torch_dtype=torch.bfloat16,

            _fast_inference=False,
            use_safetensors=True
        )

        # NOTE(review): `is_main_process` is kept exactly as in the original
        # call; confirm that load_lora_weights accepts it.
        pipe.load_lora_weights(
            lora_path,
            is_main_process=False 
        )
        pipe.to(device)

        for mask_name in file_list:
            output_path = os.path.join(save_root, mask_name.replace(".png", ".jpg"))
            if os.path.exists(output_path):
                # Already generated (e.g. by a previous, interrupted run).
                continue

            mask_path = os.path.join(TEST_ROOT, "mask", mask_name)
            txt_path = os.path.join(TEST_ROOT, "text", mask_name.replace(".png", ".txt"))

            try:
                with torch.cuda.device(device):
                    # Close the source file deterministically; convert()
                    # returns an independent in-memory image.
                    with Image.open(mask_path) as src_image:
                        mask_image = src_image.convert('RGB').resize((512, 512))

                    # Only the first line of the caption file is used.
                    with open(txt_path, "r", encoding="utf-8") as f:
                        prompt = f.readline().strip()

                    torch.cuda.synchronize(device)

                    # A fresh generator with a fixed seed per image keeps
                    # each output independent of processing order.
                    gen_img = pipe(
                            image=mask_image,
                            prompt=prompt,
                            height=512,
                            width=512,
                            num_inference_steps=28,
                            guidance_scale=1,
                            max_sequence_length=512,
                            generator=torch.Generator(device).manual_seed(42),
                        ).images[0]

                    torch.cuda.synchronize(device)

                    gen_img.save(output_path)
                    print(f"[GPU{gpu_id}] 处理完成: {mask_name} -> {output_path}")

            except Exception as e:
                # Best effort: report, release cached GPU memory, and move on
                # to the next file.
                print(f"[GPU{gpu_id}] 处理失败 {mask_name}: {str(e)}")
                torch.cuda.empty_cache()  

    except Exception as e:
        # Setup failed (device selection or model/LoRA loading).
        print(f"[GPU{gpu_id}] 进程启动失败: {str(e)}")

if __name__ == "__main__":
    # Usage: python <script> <lora_path> <save_root>
    if len(sys.argv) < 3:
        print(f"用法: python {sys.argv[0]} <lora_path> <save_root>")
        sys.exit(1)

    lora_path = sys.argv[1]
    save_root = sys.argv[2]

    print(lora_path,save_root)
    # Create the output directory itself. os.path.dirname(save_root) would
    # drop the last path component when save_root has no trailing slash and
    # is "" (which makedirs rejects) for a bare directory name.
    os.makedirs(save_root, exist_ok=True)

    # CUDA cannot be re-initialised in a forked child, so workers are spawned.
    set_start_method('spawn', force=True)

    # Sorted so the split across GPUs is deterministic; os.listdir order is
    # unspecified.
    mask_dir = os.path.join(TEST_ROOT, "mask")
    all_files = sorted(f for f in os.listdir(mask_dir) if f.endswith(".png"))

    # Only schedule masks whose output image does not exist yet.
    pending_files = [
        f for f in all_files
        if not os.path.exists(os.path.join(save_root, f.replace(".png", ".jpg")))
    ]

    if not pending_files:
        print("所有文件已处理完成")
        sys.exit(0)

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误：未检测到可用GPU")
        sys.exit(1)

    # One contiguous, near-equal chunk per GPU.
    file_chunks = np.array_split(pending_files, num_gpus)

    processes = []
    for gpu_id, chunk_array in enumerate(file_chunks):
        chunk = chunk_array.tolist()
        if not chunk:
            # Fewer pending files than GPUs: do not start a worker that would
            # only load the model and exit.
            continue
        p = Process(target=process_files, args=(gpu_id, chunk,save_root,lora_path))
        processes.append(p)
        p.start()
        print(f"启动进程 {gpu_id} 处理 {len(chunk)} 个文件")

    for p in processes:
        p.join()

    # A non-zero exit code means a worker died abnormally (e.g. was killed);
    # do not report success in that case.
    failed = [p for p in processes if p.exitcode != 0]
    if failed:
        print(f"错误：{len(failed)} 个进程异常退出")
        sys.exit(1)

    print("所有任务处理完成")
