import os
import cv2
import numpy as np
from PIL import Image
from api_generate import seamless_albedo_generate
from utils import load_image_as_tensor, save_masked_image


def crop_material_region(final_render_path, mask_path, output_dir, obj_idx):
    """
    根据mask裁切final_render中的材质区域
    将mask区域以外的部分设为透明
    
    Args:
        final_render_path: final_render.png的路径
        mask_path: 对应物体的mask路径
        output_dir: 输出目录
        obj_idx: 物体索引
    
    Returns:
        str: 裁切后的图片路径
    """
    # 读取final_render和mask
    final_render = cv2.imread(final_render_path)
    final_render = cv2.cvtColor(final_render, cv2.COLOR_BGR2RGB)
    
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    
    # 确保尺寸一致
    if final_render.shape[:2] != mask.shape[:2]:
        mask = cv2.resize(mask, (final_render.shape[1], final_render.shape[0]))
    
    # 创建RGBA图像（添加透明通道）
    rgba_image = np.zeros((final_render.shape[0], final_render.shape[1], 4), dtype=np.uint8)
    rgba_image[:, :, :3] = final_render  # RGB通道
    
    # 应用mask：mask区域设为不透明（alpha=255），其他区域设为透明（alpha=0）
    mask_bool = mask > 127  # 二值化mask
    rgba_image[:, :, 3] = mask_bool.astype(np.uint8) * 255  # Alpha通道
    
    # 找到mask的边界框
    coords = np.where(mask_bool)
    if len(coords[0]) == 0:
        print(f"Warning: No valid mask region found for obj_{obj_idx}")
        return None
    
    y_min, y_max = coords[0].min(), coords[0].max()
    x_min, x_max = coords[1].min(), coords[1].max()
    
    # 裁切区域（包含透明背景）
    cropped_region = rgba_image[y_min:y_max+1, x_min:x_max+1]
    
    # 保存裁切后的图片（PNG格式支持透明）
    os.makedirs(output_dir, exist_ok=True)
    cropped_path = os.path.join(output_dir, f"cropped_obj_{obj_idx}.png")
    
    cropped_pil = Image.fromarray(cropped_region, 'RGBA')
    cropped_pil.save(cropped_path)
    
    print(f"✅ 已裁切obj_{obj_idx}的材质区域（透明背景）: {cropped_path}")
    return cropped_path


def generate_seamless_textures_for_materials(final_render_path, mask_paths, output_dir, obj_idx_list):
    """
    为所有材质区域生成无缝纹理
    
    Args:
        final_render_path: final_render.png的路径
        mask_paths: 所有mask路径的列表
        output_dir: 输出目录
        obj_idx_list: 物体索引列表
    
    Returns:
        dict: 物体索引到无缝纹理路径的映射
    """
    seamless_dir = os.path.join(output_dir, "seamless")
    os.makedirs(seamless_dir, exist_ok=True)
    
    seamless_texture_paths = {}
    
    for obj_idx, mask_path in zip(obj_idx_list, mask_paths):
        print(f"\n🔄 正在为obj_{obj_idx}生成无缝纹理...")
        
        # 检查是否已存在无缝纹理
        expected_seamless_path = os.path.join(seamless_dir, f"seamless_{obj_idx}.png")
        if os.path.exists(expected_seamless_path):
            # 检查文件大小，确保不是空文件
            if os.path.getsize(expected_seamless_path) > 0:
                seamless_texture_paths[obj_idx] = expected_seamless_path
                print(f"✅ 发现已存在的无缝纹理，跳过API生成: {expected_seamless_path}")
                continue
            else:
                print(f"⚠️ 发现空文件，将重新生成: {expected_seamless_path}")
        
        # 裁切材质区域
        cropped_path = crop_material_region(final_render_path, mask_path, seamless_dir, obj_idx)
        
        if cropped_path is None:
            print(f"❌ 无法裁切obj_{obj_idx}的材质区域，跳过")
            continue
        
        # 生成无缝纹理
        seamless_path = os.path.join(seamless_dir, f"seamless_{obj_idx}.png")
        
        try:
            result_path = seamless_albedo_generate(
                input_path=cropped_path,
                output_dir=seamless_dir,
                mask_path=None,  # 不需要额外mask，因为已经裁切了
                prompt="Generate a seamless, tileable texture inspired by the main material in this image. "
                       "Do NOT include any object outline, stitching, seams or folds. "
                       "Only retain the base material — its color, surface pattern, and texture. "
                       "Output must be 1024x1024, horizontally and vertically tileable; "
                       "contain no lighting gradients or shadows; "
                       "appear like a flat, uniform material swatch with consistent roughness."
            )
            
            if result_path:
                # 检查生成的文件是否存在
                if os.path.exists(result_path):
                    # 重命名为标准格式
                    new_path = os.path.join(seamless_dir, f"seamless_{obj_idx}.png")
                    if result_path != new_path:  # 只有当路径不同时才重命名
                        os.rename(result_path, new_path)
                    seamless_texture_paths[obj_idx] = new_path
                    print(f"✅ 已生成obj_{obj_idx}的无缝纹理: {new_path}")
                else:
                    print(f"❌ 生成的无缝纹理文件不存在: {result_path}")
            else:
                print(f"❌ 生成obj_{obj_idx}的无缝纹理失败")
                
        except Exception as e:
            print(f"❌ 生成obj_{obj_idx}的无缝纹理时出错: {e}")
            continue
    
    return seamless_texture_paths


def get_mask_paths_for_objects(transfer_data_dir, model_name, model_index, obj_idx_list):
    """
    获取指定物体的mask路径
    
    Args:
        transfer_data_dir: 传输数据目录
        model_name: 模型名称
        model_index: 模型索引
        obj_idx_list: 物体索引列表
    
    Returns:
        list: mask路径列表
    """
    # mask文件路径格式：/home/swu/szp/MaterialTransfer/images/chair-002/mask_1.png
    dir_name = os.path.join(transfer_data_dir, f"{model_name}-{model_index}")
    mask_paths = []
    
    for obj_idx in obj_idx_list:
        mask_path = os.path.join(dir_name, f"mask_{obj_idx}.png")
        if os.path.exists(mask_path):
            mask_paths.append(mask_path)
            print(f"✅ 找到obj_{obj_idx}的mask文件: {mask_path}")
        else:
            print(f"❌ 未找到obj_{obj_idx}的mask文件: {mask_path}")
            mask_paths.append(None)
    
    return mask_paths


def auto_generate_seamless_textures(optimizer_output_dir, transfer_data_dir, model_name, model_index, target_obj_list):
    """
    自动生成无缝纹理的主函数
    
    Args:
        optimizer_output_dir: 优化器输出目录
        transfer_data_dir: 传输数据目录
        model_name: 模型名称
        model_index: 模型索引
        target_obj_list: 用户指定的目标物体索引列表（只对这些物体生成纹理）
    
    Returns:
        dict: 物体索引到无缝纹理路径的映射
    """
    final_render_path = os.path.join(optimizer_output_dir, "final_render.png")
    
    if not os.path.exists(final_render_path):
        print(f"❌ final_render.png不存在: {final_render_path}")
        return {}
    
    if not target_obj_list:
        print("❌ 未指定目标物体列表")
        return {}
    
    print(f"🎯 目标物体列表: {target_obj_list}")
    
    # 获取指定物体的mask路径
    mask_paths = get_mask_paths_for_objects(transfer_data_dir, model_name, model_index, target_obj_list)
    
    # 过滤掉不存在的mask
    valid_pairs = [(idx, path) for idx, path in zip(target_obj_list, mask_paths) if path is not None]
    
    if not valid_pairs:
        print("❌ 没有找到有效的mask文件")
        return {}
    
    valid_obj_indices = [pair[0] for pair in valid_pairs]
    valid_mask_paths = [pair[1] for pair in valid_pairs]
    
    print(f"✅ 找到 {len(valid_obj_indices)} 个有效物体: {valid_obj_indices}")
    
    # 生成无缝纹理
    seamless_texture_paths = generate_seamless_textures_for_materials(
        final_render_path, valid_mask_paths, optimizer_output_dir, valid_obj_indices
    )
    
    return seamless_texture_paths 