import os
from utils import save_masked_image, linear_to_srgb_drjit, save_material_pair_json, load_material_pair_json
from inverse_rendering_helper import inverse_rendering_func, MaterialOptimizer
from intrinsic_decompose import intrinsic_decomp
from pair_materials import pair_materials
from config import TRANSFER_DATA_DIR
from seamless_generator import auto_generate_seamless_textures


def _collect_object_masks(scene_dir):
    """Return the paths of the per-object mask images (``mask_*.png``) in *scene_dir*.

    The result is sorted by file name: ``os.listdir`` returns entries in an
    arbitrary order, and this order is forwarded to ``pair_materials``.
    """
    return [os.path.join(scene_dir, name)
            for name in sorted(os.listdir(scene_dir))
            if name.startswith("mask_") and name.endswith(".png")]


def _resolve_manual_seamless_textures(i, scene_dir, seamless_texture_idx_list, seamless_texture_list):
    """Map object indices to the manually chosen seamless textures of experiment *i*.

    Args:
        i: Index of the experiment in the per-experiment lists.
        scene_dir: Directory containing the ``seamless_<id>.png`` files.
        seamless_texture_idx_list: Per-experiment lists of object indices.
        seamless_texture_list: Per-experiment lists of texture ids, parallel to
            ``seamless_texture_idx_list``.

    Returns:
        ``(paths, obj_indices)`` where *paths* maps each object index to
        ``<scene_dir>/seamless_<texture_id>.png`` and *obj_indices* is
        ``seamless_texture_idx_list[i]``; or ``None`` when either list has no
        entry for experiment *i*.

    Raises:
        ValueError: if experiment *i* has fewer texture ids than object indices.
    """
    if not (seamless_texture_idx_list and seamless_texture_list
            and i < len(seamless_texture_idx_list) and i < len(seamless_texture_list)):
        return None

    obj_indices = seamless_texture_idx_list[i]
    texture_ids = seamless_texture_list[i]
    if len(texture_ids) < len(obj_indices):
        raise ValueError(
            f"Experiment {i}: {len(obj_indices)} object indices but only "
            f"{len(texture_ids)} seamless texture ids"
        )

    paths = {}
    for obj_idx, texture_id in zip(obj_indices, texture_ids):
        # setdefault: for a duplicated object index the first texture wins,
        # matching the earlier list.index()-based lookup.
        paths.setdefault(obj_idx, os.path.join(scene_dir, f"seamless_{texture_id}.png"))
    return paths, obj_indices


def run_inverse_rendering(examples, seamless_texture_idx_list=None, seamless_texture_list=None, auto_generate=True,
                          target_obj_lists=None):
    """Run the material-transfer pipeline for every experiment in *examples*.

    Each experiment goes through these steps:

    1. Set up a ``MaterialOptimizer`` scene. ``final_render.png`` is rendered
       via ``optimize_drjit`` only if it does not exist yet.
    2. Run intrinsic decomposition if the roughness or metallic map is missing.
    3. Pair materials with the object masks, cached in ``paired_materials.json``.
    4. Resolve the seamless textures to use.
    5. Call ``inverse_rendering_func``.

    Args:
        examples: List of experiment configs. Each dict must provide
            ``"model_name"``, ``"model_index"`` and ``"gt_idx"``.
        seamless_texture_idx_list: Manual mode only. Per-experiment lists of
            object indices that receive a seamless texture.
        seamless_texture_list: Manual mode only. Per-experiment lists of
            texture ids, parallel to ``seamless_texture_idx_list``.
        auto_generate: If True, generate seamless textures with
            ``auto_generate_seamless_textures``; otherwise use the manual lists.
        target_obj_lists: Per-experiment target object lists, formatted as
            ``[[obj1, obj2], [obj3], ...]``. Experiments without an entry use
            every object that has a mask. Only consulted in auto mode.

    Experiments for which no seamless texture can be resolved are skipped.
    """
    for i, example in enumerate(examples):
        model_name = example["model_name"]
        model_index = example["model_index"]
        print(f"Running optimization for {model_name}-{model_index}, GT index: {example['gt_idx']}")

        optimizer = MaterialOptimizer(
            model_name=model_name,
            model_index=model_index,
            gt_idx=example["gt_idx"],
            resolution=512,
            total_iter=200,
            eval_step=20,
            loss_type="mae",
            use_tv_loss=False,
            tv_loss_weight=0.01,
            seed=1,
            device="cuda"
        )
        optimizer._setup_scene()

        # The target image is rendered once and reused on later runs.
        target_image_path = os.path.join(optimizer.output_dir, "final_render.png")
        if not os.path.exists(target_image_path):
            optimizer.optimize_drjit()

        # Scene directory with the per-object masks and the normal map. It is
        # already rooted at TRANSFER_DATA_DIR, so it must not be joined with it again.
        scene_dir = os.path.join(TRANSFER_DATA_DIR, f"{model_name}-{model_index}")
        obj_mask_paths = _collect_object_masks(scene_dir)

        # Normal, roughness and metallic maps. Re-run the decomposition when
        # either output is missing.
        normal_path = os.path.join(scene_dir, "normal.png")
        rgbx_dir = os.path.join(optimizer.output_dir, "final_rgbx")
        roughness_image_path = os.path.join(rgbx_dir, "final_render_roughness.png")
        metallic_image_path = os.path.join(rgbx_dir, "final_render_metallic.png")
        if not (os.path.exists(roughness_image_path) and os.path.exists(metallic_image_path)):
            intrinsic_decomp(target_image_path, rgbx_dir, max_side=512, samples=1, steps=50, seed=0, mask_path=None)

        # Material pairing, cached on disk as JSON.
        paired_materials_path = os.path.join(optimizer.output_dir, "paired_materials.json")
        if os.path.exists(paired_materials_path):
            paired_materials = load_material_pair_json(paired_materials_path)
        else:
            paired_materials = pair_materials(obj_mask_paths, normal_path,
                                              roughness_image_path, metallic_image_path,
                                              output_path=None, device='cuda')
            save_material_pair_json(paired_materials, paired_materials_path)

        # Object indices parsed from mask file names such as "mask_3.png".
        all_obj_idx_list = [int(os.path.basename(p).split('_')[1].split('.')[0]) for p in obj_mask_paths]
        print(f" 所有可用物体: {all_obj_idx_list}")

        if target_obj_lists and i < len(target_obj_lists):
            target_obj_list = target_obj_lists[i]
            print(f" 实验 {i} 的目标物体: {target_obj_list}")
        else:
            # No explicit targets for this experiment: use every object.
            target_obj_list = all_obj_idx_list
            print(f" 使用所有物体作为目标: {target_obj_list}")

        # Seamless textures: auto-generated, or taken from the manual lists.
        if auto_generate:
            print(" 自动生成无缝纹理...")
            seamless_texture_paths = auto_generate_seamless_textures(
                optimizer.output_dir, TRANSFER_DATA_DIR,
                model_name, model_index, target_obj_list
            )
            if not seamless_texture_paths:
                print(" 未能生成任何无缝纹理，跳过材质传输")
                continue
            seamless_obj_indices = list(seamless_texture_paths.keys())
            print(f" 成功生成无缝纹理的物体: {seamless_obj_indices}")
        else:
            resolved = _resolve_manual_seamless_textures(
                i, scene_dir, seamless_texture_idx_list, seamless_texture_list)
            if resolved is None:
                print("❌ 未提供手动指定的无缝纹理配置，跳过材质传输")
                continue
            seamless_texture_paths, seamless_obj_indices = resolved

        inverse_rendering_func(os.fspath(optimizer.obj_file_path), optimizer.camera_dict, paired_materials,
                               optimizer.output_dir, seamless_texture_paths,
                               seamless_obj_indices,
                               [0, 0, 0], 1.0, [90, 90, 0],  # mesh translation, scale, rotation
                               optimizer.mask_image_path, target_image_path,
                               train_resolution=(512, 512), fw_spp=256, bw_spp=64, learning_rate=0.01, max_steps=200)


if __name__ == "__main__":
    rule = "=" * 60
    print(rule)
    print("🎯 材质传输系统 - 自动无缝纹理生成")
    print(rule)

    # Experiment configurations: one dict per model/scene to process.
    examples = [
        {"model_name": "bag", "model_index": "005", "gt_idx": 2},
    ]

    # Manual mode: per experiment, the object indices that receive a seamless texture ...
    seamless_texture_idx_list = [[1, 3]]
    # ... and, in parallel, the seamless texture id used by each of those objects.
    seamless_texture_list = [[5, 5]]

    # True: generate seamless textures automatically; False: use the manual lists above.
    AUTO_GENERATE = False

    # Auto mode only: per experiment, the object indices to generate seamless textures for.
    target_obj_lists = [[0]]

    if not AUTO_GENERATE:
        run_inverse_rendering(examples, seamless_texture_idx_list, seamless_texture_list, auto_generate=False)
    else:
        run_inverse_rendering(examples, auto_generate=True, target_obj_lists=target_obj_lists)

    print(rule)
    print("材质传输完成！")
    print(rule)

