#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import shutil
import csv
from pathlib import Path

import torch
from PIL import Image
from tqdm.auto import tqdm

from diffusers import AutoencoderKL, FluxPipeline, FluxTransformer2DModel
from diffusers.utils import check_min_version

# Abort early if the installed diffusers release is older than this script requires.
check_min_version("0.33.0.dev0")

# Module-level logger; its level/handlers are configured in main() via logging.basicConfig.
logger = logging.getLogger(name=__name__)


def parse_args(input_args=None):
    """Parse the command-line options for Flux (LoRA / UCE) batch inference.

    Args:
        input_args: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Defaults to ``None``, which reads the real
            command line, so existing callers are unaffected.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Simple example of an inference script for Flux LoRA DreamBooth.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--lora_weights_path",
        type=str,
        default=None,
        help="Path to the trained LoRA weights.",
    )
    parser.add_argument(
        "--lora_weight_name",
        type=str,
        default="pytorch_lora_weights.safetensors",
        help="Name of the LoRA weight file.",
    )
    parser.add_argument(
        "--prompt_csv",
        type=str,
        required=True,
        help="Path to the CSV file containing a 'prompt' column (one image is generated per row).",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="inference_results",
        help="The output directory where the generated images will be written.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=28,
        help="Number of inference steps.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.5,
        help="Guidance scale for the diffusion process.",
    )
    parser.add_argument(
        "--num_images_per_prompt",
        type=int,
        default=1,
        help="Number of images to generate per prompt.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=512,
        help="The height in pixels of the generated image.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=512,
        help="The width in pixels of the generated image.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible inference.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
    )
    parser.add_argument(
        "--uce_path",
        type=str,
        default=None,
        # Previously this help text was a copy of --variant's and described the wrong option.
        help="Path to a safetensors file of UCE-edited transformer weights, loaded non-strictly on top of the base model.",
    )

    return parser.parse_args(input_args)


def _select_dtype(device):
    """Return the torch dtype used to load the pipeline on ``device``.

    bfloat16 on GPUs that support it, float16 on other GPUs, and float32 on
    CPU, where many float16 kernels are missing or extremely slow.
    """
    if device == "cuda":
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32


def _load_uce_weights(transformer, uce_path):
    """Load UCE-edited weights from the safetensors file ``uce_path`` into ``transformer``.

    Loading is non-strict so a file containing only a subset of the
    transformer's parameters is accepted (presumably only the edited layers —
    confirm against the UCE export).

    Raises:
        RuntimeError: if no tensor in the file matches a transformer parameter,
            which would otherwise silently leave the base model unchanged.
    """
    from safetensors.torch import load_file

    uce_weights = load_file(uce_path)
    result = transformer.load_state_dict(uce_weights, strict=False)

    matched = len(uce_weights) - len(result.unexpected_keys)
    if matched == 0:
        raise RuntimeError(
            f"No tensors from {uce_path} matched the transformer's parameters; UCE weights were not applied."
        )
    if result.unexpected_keys:
        logger.warning(
            f"{len(result.unexpected_keys)} keys in {uce_path} were ignored "
            f"(e.g. {sorted(result.unexpected_keys)[:5]})"
        )
    logger.info(f"UCE model loaded successfully ({matched} tensors applied)")


def _read_prompts(csv_path):
    """Return the values of the ``prompt`` column of ``csv_path``, in file order.

    Raises:
        ValueError: if the CSV has a header row without a ``prompt`` column.
    """
    # newline="" is what the csv module requires; utf-8-sig decodes plain UTF-8
    # and also strips a BOM (e.g. from spreadsheet exports) so the header parses.
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        if reader.fieldnames is not None and "prompt" not in reader.fieldnames:
            raise ValueError(f"{csv_path} must contain a 'prompt' column; found columns {reader.fieldnames}")
        return [row["prompt"] for row in reader]


def _image_path(output_dir, prompt_idx, image_idx, num_images_per_prompt):
    """Build the output path: ``0007.png`` for one image per prompt, ``0007_2.png`` otherwise."""
    if num_images_per_prompt == 1:
        filename = f"{prompt_idx:04d}.png"
    else:
        filename = f"{prompt_idx:04d}_{image_idx + 1}.png"
    return os.path.join(output_dir, filename)


def main():
    """Generate images for every prompt in ``--prompt_csv`` with a Flux pipeline.

    Optionally applies UCE weights (``--uce_path``) and then LoRA weights
    (``--lora_weights_path``) on top of the base model, and writes images to
    ``--output_dir`` named by the prompt's row index in the CSV.
    """
    args = parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    logging.basicConfig(level=logging.INFO)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    if args.seed is not None:
        # torch.manual_seed seeds the CPU and every CUDA device.
        torch.manual_seed(args.seed)

    logger.info("Loading pipeline...")
    pipeline = FluxPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        torch_dtype=_select_dtype(device),
        variant=args.variant,
    )

    if args.uce_path is not None:
        _load_uce_weights(pipeline.transformer, args.uce_path)

    if args.lora_weights_path is not None:
        logger.info(f"Loading LoRA weights from {args.lora_weights_path}")
        pipeline.load_lora_weights(args.lora_weights_path, weight_name=args.lora_weight_name)

    pipeline = pipeline.to(device)

    logger.info(f"Reading prompts from {args.prompt_csv}")
    prompts = _read_prompts(args.prompt_csv)
    logger.info(f"Found {len(prompts)} prompts in CSV file")

    for idx, prompt in enumerate(tqdm(prompts, desc="Generating images")):
        logger.info(f"Generating image for {idx} with prompt: {prompt}")

        # Re-seed for every prompt so each prompt's images depend only on the
        # seed, not on how many prompts precede it in the CSV.
        generator = torch.Generator(device=device).manual_seed(args.seed) if args.seed is not None else None
        images = pipeline(
            prompt=prompt,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            num_images_per_prompt=args.num_images_per_prompt,
            height=args.height,
            width=args.width,
            generator=generator,
        ).images

        for i, image in enumerate(images):
            image_path = _image_path(args.output_dir, idx, i, args.num_images_per_prompt)
            image.save(image_path)
            logger.info(f"Saved image to {image_path}")

    logger.info("Inference completed!")


if __name__ == "__main__":
    # Only run inference when executed as a script, not when imported.
    main()