import argparse
import functools
import os
from pathlib import Path

import numpy as np
import torch
import torchvision
from PIL import Image
from tqdm import tqdm

# Enable OpenEXR support in OpenCV's image I/O. The variable is set here, between
# the import groups, so that it is in the environment before the project imports
# below run — presumably `utils` (load_exr_image) uses OpenCV, which reads this flag
# when first imported (TODO confirm). Do not move it after those imports.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

from diffusers import DDIMScheduler, DDPMScheduler
from utils import load_exr_image, load_ldr_image
from pipeline_rgbx import StableDiffusionAOVMatEstPipeline


#单张RGB图像分解生成AOV
def process_single_image(
    pipe,
    image_path,
    output_dir,
    seed=42,
    inference_steps=50,
    num_samples=1,
    max_side=1000,
    mask_path=None,
):
    """Decompose one RGB image into albedo / normal / roughness / metallic AOV maps.

    The photo is loaded (``.exr`` via ``load_exr_image(..., tonemaping=True,
    clamp=True)``, ``.png``/``.jpg``/``.jpeg`` via ``load_ldr_image(...,
    from_srgb=True)``), optionally multiplied by a mask, resized so that its long
    side is ``max_side`` (both sides then floored to multiples of 8), and passed
    through ``pipe`` once per AOV. Each result is resized back to the original
    resolution and written to ``<output_dir>[/sample_<i>]/<stem>_<aov>.png``.

    Args:
        pipe: RGB->X pipeline, called with ``prompt``, ``photo``,
            ``num_inference_steps``, ``height``, ``width``, ``generator`` and
            ``required_aovs``; ``.images[0][0]`` of its result is used.
        image_path: Input image path (str or path-like). The extension is matched
            case-insensitively against .exr/.png/.jpg/.jpeg.
        output_dir: Directory that receives the AOV PNGs; created if missing.
        seed: Seed for the CUDA ``torch.Generator``; ``-1`` draws a random seed,
            which is printed so the run can be reproduced.
        inference_steps: Number of denoising steps per AOV.
        num_samples: Number of samples to draw. With more than one, each sample
            is written to its own ``sample_<i>`` subdirectory.
        max_side: Target length of the long side used during inference.
        mask_path: Optional LDR image that is multiplied into the photo before
            inference (resized to the photo's size if needed).

    Returns:
        None. Unsupported formats are reported and skipped before any directory
        is created or any GPU work is done.
    """
    image_path = os.fspath(image_path)
    print(f"处理图像: {image_path}")

    # Validate the format first (case-insensitively, so "photo.PNG" is accepted)
    # so unsupported files create no empty output directory and never touch the GPU.
    lowered_path = image_path.lower()
    is_exr = lowered_path.endswith(".exr")
    if not is_exr and not lowered_path.endswith((".png", ".jpg", ".jpeg")):
        print(f"不支持的图像格式: {image_path}")
        return

    os.makedirs(output_dir, exist_ok=True)

    if seed == -1:
        seed = torch.randint(0, 2147483647, (1,)).item()
        # Report the drawn seed so a "random" run can be reproduced later.
        print(f"随机种子: {seed}")
    # One generator is shared across all AOVs and samples, so successive calls
    # draw different noise from a single seeded stream.
    generator = torch.Generator(device="cuda").manual_seed(seed)

    if is_exr:
        photo = load_exr_image(image_path, tonemaping=True, clamp=True).to("cuda")
    else:
        photo = load_ldr_image(image_path, from_srgb=True).to("cuda")

    if mask_path is not None:
        mask = load_ldr_image(mask_path, from_srgb=True).to("cuda")
        # Match the mask's spatial size to the photo's. The shape[1]/shape[2]
        # indexing assumes (C, H, W) tensors — TODO confirm against utils.
        if mask.shape[1:] != photo.shape[1:]:
            mask = torchvision.transforms.Resize((photo.shape[1], photo.shape[2]))(mask)
        photo = photo * mask

    # Remember the original size so every AOV can be resized back to it.
    old_height, old_width = photo.shape[1], photo.shape[2]
    new_height, new_width = _target_size(old_height, old_width, max_side)
    photo = torchvision.transforms.Resize((new_height, new_width))(photo)

    # AOVs to generate, each with the text prompt that selects it in the pipeline.
    required_aovs = ["albedo", "normal", "roughness", "metallic"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
    }

    base_filename = os.path.splitext(os.path.basename(image_path))[0]

    for sample_idx in range(num_samples):
        sample_dir = output_dir
        if num_samples > 1:
            sample_dir = os.path.join(output_dir, f"sample_{sample_idx}")
            os.makedirs(sample_dir, exist_ok=True)

        for aov_name in tqdm(required_aovs, desc=f"样本 {sample_idx+1}/{num_samples}"):
            generated_image = pipe(
                prompt=prompts[aov_name],
                photo=photo,
                num_inference_steps=inference_steps,
                height=new_height,
                width=new_width,
                generator=generator,
                required_aovs=[aov_name],
            ).images[0][0]
            generated_image = _postprocess_aov(aov_name, generated_image)

            out_path = os.path.join(sample_dir, f"{base_filename}_{aov_name}.png")
            # Only report success when a file was actually written.
            if _save_aov(generated_image, old_height, old_width, out_path):
                print(f"已保存: {out_path}")


def _target_size(old_height, old_width, max_side):
    """Return ``(height, width)`` with the long side scaled to ``max_side``, floored to multiples of 8."""
    ratio = old_height / old_width
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)
    # Both sides must be multiples of 8 (presumably the VAE/UNet downsampling
    # factor). Clamp to at least 8 so extreme aspect ratios cannot yield 0.
    return max(8, new_height // 8 * 8), max(8, new_width // 8 * 8)


def _postprocess_aov(aov_name, image):
    """Clamp PIL metallic/roughness outputs to plausible ranges; anything else passes through."""
    if not isinstance(image, Image.Image):
        return image
    if aov_name == "metallic":
        values = np.array(image, dtype=np.float32) / 255
        max_metallic = values.max()
        # Rescale so the brightest metallic value is at most 0.2.
        if max_metallic > 0.2:
            values = 0.2 * values / max_metallic
        return Image.fromarray((values * 255).astype(np.uint8))
    if aov_name == "roughness":
        # Keep roughness away from the 0 / 255 extremes.
        return Image.fromarray(np.array(image).clip(20, 242))
    return image


def _save_aov(image, height, width, out_path):
    """Resize an AOV back to ``(height, width)`` and write it to ``out_path``; return True if written."""
    if isinstance(image, torch.Tensor):
        image = torchvision.transforms.Resize((height, width))(image)
        if image.dim() != 3:
            print(f"警告: 不支持的张量形状: {image.shape}")
            return False
        # ToPILImage expects a (C, H, W) tensor.
        torchvision.transforms.ToPILImage()(image).save(out_path)
        return True
    if isinstance(image, Image.Image):
        # PIL's resize takes (width, height), the reverse of torch's (height, width).
        image.resize((width, height), Image.LANCZOS).save(out_path)
        return True
    print(f"警告: 不支持的图像类型: {type(image)}")
    return False

#调用单张图片处理的函数对文件中的所有图片进行处理
def process_directory(pipe, input_dir, output_dir, seed, inference_steps, num_samples, max_side, mask_path=None):
    """Run ``process_single_image`` on every supported image directly inside ``input_dir``.

    Regular files whose extension is .exr/.png/.jpg/.jpeg (matched
    case-insensitively, non-recursive) are processed in sorted order. Each image's
    outputs go to ``<output_dir>/<file stem>``. The same ``mask_path`` (if any) is
    applied to every image. A missing or image-less directory is reported and
    nothing is processed.
    """
    image_extensions = {".exr", ".png", ".jpg", ".jpeg"}

    # Compare lower-cased suffixes so "IMG_01.JPG" is found; glob("*.png") would
    # miss upper-case suffixes. Sorting gives a deterministic processing order.
    root = Path(input_dir)
    image_files = (
        sorted(
            path
            for path in root.iterdir()
            if path.is_file() and path.suffix.lower() in image_extensions
        )
        if root.is_dir()
        else []
    )

    if not image_files:
        print(f"在 {input_dir} 中未找到支持的图像文件")
        return

    print(f"找到 {len(image_files)} 个图像文件")

    for img_path in image_files:
        img_output_dir = os.path.join(output_dir, img_path.stem)
        process_single_image(
            pipe,
            str(img_path),
            img_output_dir,
            seed,
            inference_steps,
            num_samples,
            max_side,
            mask_path,
        )


_DEFAULT_MODEL_PATH = "/home/swu/cyh/MasterGraduationProject/pretrained_models/rgb-to-x"


def intrinsic_decomp(input_path, output_dir, max_side=512 ,samples=1, steps=50, seed=0, mask_path=None,
                     pipe=None, model_path=None):
    """Decompose an image file, or every image in a directory, into AOV maps.

    Args:
        input_path: Image file, or a directory of images, to process.
        output_dir: Output directory. For a directory input, each image gets an
            ``<output_dir>/<stem>`` subdirectory.
        max_side: Long-side length used during inference.
        samples: Number of samples per image.
        steps: Number of denoising steps per AOV.
        seed: Generator seed (``-1`` for a random seed).
        mask_path: Optional mask image applied to every input image.
        pipe: Optional preloaded pipeline. When ``None``, one is loaded from
            ``model_path`` and reused by later calls with the same path, so
            calling this in a loop does not reload the model each time.
        model_path: Pretrained model directory; defaults to ``_DEFAULT_MODEL_PATH``.

    Returns:
        None. A missing input path is reported before any model is loaded.
    """
    # Pick the handler first: both share the same positional signature, and a
    # missing input should not cost a model load.
    if os.path.isfile(input_path):
        handler = process_single_image
    elif os.path.isdir(input_path):
        handler = process_directory
    else:
        print(f"错误: 输入路径 {input_path} 不存在")
        return

    if pipe is None:
        pipe = _load_pipeline(model_path or _DEFAULT_MODEL_PATH)

    handler(
        pipe,
        input_path,
        output_dir,
        seed,
        steps,
        samples,
        max_side,
        mask_path,
    )


@functools.lru_cache(maxsize=1)
def _load_pipeline(model_path):
    """Load the fp16 CUDA pipeline with the zero-SNR/trailing DDIM scheduler; cached per path."""
    print("正在加载模型...")
    pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
    ).to("cuda")

    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
    )
    pipe.set_progress_bar_config(disable=True)
    print("模型加载完成!")
    return pipe


def main(argv=None):
    """Command-line entry point: decompose an image file or every image in a directory.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Returns:
        None. A missing ``--input`` path is reported before the model is loaded.
    """
    parser = argparse.ArgumentParser(description="RGB到内在属性转换工具")
    parser.add_argument("--input", required=True, help="输入图像或图像文件夹的路径")
    parser.add_argument("--output", required=True, help="输出目录路径")
    parser.add_argument("--seed", type=int, default=-1, help="随机种子 (-1表示随机)")
    parser.add_argument("--steps", type=int, default=50, help="推理步数")
    parser.add_argument("--samples", type=int, default=1, help="为每个输入生成的样本数量")
    parser.add_argument("--max_side", type=int, default=1000, help="处理图像的最大边长")
    parser.add_argument("--mask", default=None, help="可选的掩码图像路径（应用到每张输入图像）")
    parser.add_argument(
        "--model",
        default="/home/swu/cyh/MasterGraduationProject/pretrained_models/rgb-to-x",
        help="预训练模型路径",
    )
    args = parser.parse_args(argv)

    input_path = args.input
    # Validate the input before the (expensive) model load. Both handlers share
    # the same positional signature.
    if os.path.isfile(input_path):
        handler = process_single_image
    elif os.path.isdir(input_path):
        handler = process_directory
    else:
        print(f"错误: 输入路径 {input_path} 不存在")
        return

    current_directory = os.path.dirname(os.path.abspath(__file__))

    print("正在加载模型...")
    pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
        args.model,
        torch_dtype=torch.float16,
        cache_dir=os.path.join(current_directory, "model_cache"),
    ).to("cuda")

    pipe.scheduler = DDIMScheduler.from_config(
        pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
    )
    pipe.set_progress_bar_config(disable=True)
    print("模型加载完成!")

    handler(
        pipe,
        input_path,
        args.output,
        args.seed,
        args.steps,
        args.samples,
        args.max_side,
        mask_path=args.mask,
    )



# Script entry point.
#
# - With command-line arguments, delegate to main() (run with --help for options).
# - Without arguments, run the hard-coded batch job below: every .png/.jpg/.jpeg/.exr
#   file (extension matched case-insensitively) directly inside `input_dir` is
#   decomposed with intrinsic_decomp(), and results for "<name>.<ext>" are
#   written to "<output_root>/<name>/".
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        main()
    else:
        # Input and output directory paths for the batch job.
        input_dir = "/home/swu/szp/MaterialTransfer/paper_images/nob/1"
        output_root = "/home/swu/szp/MaterialTransfer/paper_images/nob/"

        supported_exts = (".png", ".jpg", ".jpeg", ".exr")
        for file_name in sorted(os.listdir(input_dir)):
            if not file_name.lower().endswith(supported_exts):
                continue  # skip non-image files

            input_path = os.path.join(input_dir, file_name)
            base_name = os.path.splitext(file_name)[0]
            output_dir = os.path.join(output_root, base_name)

            print(f"开始处理: {input_path}")
            intrinsic_decomp(
                input_path=input_path,
                output_dir=output_dir,
                max_side=512,
                samples=1,
                steps=50,
                seed=0,
                mask_path=None,
            )



