#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import shutil
from pathlib import Path

import torch
from PIL import Image
from tqdm.auto import tqdm

from diffusers import AutoencoderKL, FluxPipeline, FluxTransformer2DModel
from diffusers.utils import check_min_version

# Check that the installed diffusers library meets the minimum version this script requires.
check_min_version("0.33.0.dev0")

# Module-level logger; its handlers/level are set up in main() via logging.basicConfig.
logger = logging.getLogger(__name__)


def parse_args(input_args=None):
    """Parse the command-line options for the Flux inference script.

    Args:
        input_args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Useful for tests or for calling the script
            programmatically. ``None`` (the default) parses the real command line,
            exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: via ``parser.error`` when ``--num_inference_steps``,
            ``--num_images_per_prompt``, ``--height`` or ``--width`` is not a
            positive integer, or on any standard argparse error (e.g. missing
            ``--prompt``).
    """
    parser = argparse.ArgumentParser(description="Simple example of a inference script for Flux LoRA DreamBooth.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    # NOTE: LoRA weight options (--lora_weights_path / --lora_weight_name) used to be
    # stubbed out here as commented-out code. They are omitted because this script
    # does not load LoRA weights; add them back together with the loading code.
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="The prompt to generate images from.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="inference_results",
        help="The output directory where the generated images will be written.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=28,
        help="Number of inference steps.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.5,
        help="Guidance scale for the diffusion process.",
    )
    parser.add_argument(
        "--num_images_per_prompt",
        type=int,
        default=10,
        help="Number of images to generate per prompt.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=1024,
        help="The height in pixels of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1024,
        help="The width in pixels of the generated image.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible inference.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )

    args = parser.parse_args(input_args)

    # Reject nonsensical sizes/counts here, before main() spends minutes loading the
    # model, rather than letting them fail deep inside the pipeline.
    for option in ("num_inference_steps", "num_images_per_prompt", "height", "width"):
        value = getattr(args, option)
        if value < 1:
            parser.error(f"--{option} must be a positive integer, got {value}.")

    return args


def _select_torch_dtype(device):
    """Return the torch dtype to load the pipeline in for ``device``.

    CUDA: bfloat16 when the GPU supports it, otherwise float16.
    CPU: float32, because many half-precision CPU kernels are slow or unimplemented.
    """
    if device == "cuda":
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


def main():
    """Generate images for one prompt with a pretrained Flux pipeline.

    Parses the CLI arguments and loads ``FluxPipeline`` from
    ``--pretrained_model_name_or_path`` (optionally a ``--variant``). It then runs the
    pipeline once for ``--prompt`` and saves each returned image as
    ``<output_dir>/image_<k>.png``, with k starting at 1. Files with the same
    names are overwritten.
    """
    args = parse_args()

    # Create the output directory first so a bad path fails before the slow model load.
    os.makedirs(args.output_dir, exist_ok=True)

    logging.basicConfig(level=logging.INFO)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("Using device: %s", device)

    # Also seed the global RNGs, covering any randomness outside the explicit generator.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        if device == "cuda":
            torch.cuda.manual_seed(args.seed)

    logger.info("Loading pipeline...")
    pipeline = FluxPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        torch_dtype=_select_torch_dtype(device),
        variant=args.variant,
    )
    pipeline = pipeline.to(device)

    # A dedicated seeded generator makes the sampling noise reproducible for a given seed.
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=device).manual_seed(args.seed)

    logger.info("Generating images with prompt: %s", args.prompt)
    images = pipeline(
        prompt=args.prompt,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        num_images_per_prompt=args.num_images_per_prompt,
        height=args.height,
        width=args.width,
        generator=generator,
    ).images

    logger.info("Saving %d images to %s", len(images), args.output_dir)
    for index, image in enumerate(images, start=1):
        image_path = os.path.join(args.output_dir, f"image_{index}.png")
        image.save(image_path)
        logger.info("Saved image %d to %s", index, image_path)

    logger.info("Inference completed!")


# Run inference only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
