import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import re
import cv2
import numpy as np
import drjit as dr
from mitsuba_render import (mi, T, DT, ScalarPoint3f, render_scene, render_scene_drjit,
                            create_scene_drjit, relighting_scene_drjit, create_scene_dict)
mi.set_variant('cuda_ad_rgb')
from materials import ProceduralMaterial, UniformMaterial
from tqdm import tqdm
import torch
from PIL import Image
import torch.nn.functional as F
from diffmat.optim.descriptor import TextureDescriptor
from utils import (get_obj_files_list, load_image_as_tensor, srgb_to_linear, linear_to_srgb,
                   linear_to_srgb_drjit, clean_tensor_for_gamma, save_masked_image)
from config import HDR_ROOT_DIR, MATERIAL_ROOT_DIR, TRANSFER_DATA_DIR, MESH_ROOT_DIR
import json
from torchvision.utils import save_image
from mitsuba.util import write_bitmap
from criterion import get_loss_fn_drjit, total_variation_loss_drjit
from alignment import align_to_mask


def create_texture_bitmap(texture_path, uv_scale=(1.0, 1.0), uv_rotation=0.0, raw=False, uv_flip=True):
    """Build a Mitsuba ``bitmap`` texture dictionary with a UV transform.

    Args:
        texture_path: Stored under the ``filename`` key.
        uv_scale: Pair ``(scale_x, scale_y)`` applied to the UV coordinates.
        uv_rotation: Angle handed to ``rotate`` around the z axis
            (presumably degrees -- TODO confirm against the transform API).
        raw: Stored under the ``raw`` key.
        uv_flip: When true, an extra ``[1, -1, 1]`` scale is chained onto the
            transform, mirroring the v axis.

    Returns:
        A dict with the keys ``type``, ``filename``, ``raw`` and ``to_uv``.
    """
    sx, sy = uv_scale

    # Scale first, then rotate about z; the z component of the scale stays 1.
    to_uv = T().scale([sx, sy, 1.0]).rotate([0, 0, 1], uv_rotation)
    if uv_flip:
        to_uv = to_uv.scale([1, -1, 1])

    return dict(type='bitmap', filename=texture_path, raw=raw, to_uv=to_uv)


def parse_material_pair_dict(material_pair_dict, train_resolution):
    """Expand raw material assignments into per-object description dicts.

    Each value of ``material_pair_dict`` is either ``None`` or an indexable
    record read as ``(original_size, crop_slice, category, name, param_idx)``
    where ``crop_slice`` is a ``(y_slice, x_slice)`` pair. A record whose
    ``crop_slice`` is ``None`` is treated like ``None``.

    Every returned entry carries the same set of keys so that consumers can
    index any of them without first checking which kind of entry they hold:
    ``original_size``, ``crop_slice``, ``scaled_crop_slice``,
    ``material_category``, ``material_name``, ``material_param_idx``,
    ``material_path`` and ``procedural``. Placeholder entries hold ``None``
    for all of them.

    Args:
        material_pair_dict: Mapping from object key to a record or ``None``.
        train_resolution: Training resolution; only element ``0`` is used, as
            the numerator of the crop rescaling factor.

    Returns:
        dict mapping each input key to its description dict.
    """
    parsed_material_dict = {}
    for key, value in material_pair_dict.items():
        if value is None or value[1] is None:
            # No usable material reference: emit a placeholder with the full schema.
            parsed_material_dict[key] = {
                'original_size': None,
                'crop_slice': None,
                'scaled_crop_slice': None,
                'material_category': None,
                'material_name': None,
                'material_param_idx': None,
                'material_path': None,
                'procedural': None,
            }
            continue

        original_size, crop_slice = value[0], value[1]
        mat_category, mat_name = value[2], value[3]
        mat_param_idx = int(value[4])

        # A single factor (train_resolution[0] / original_size[0]) rescales both
        # axes of the crop from the original image to the training resolution.
        scale = train_resolution[0] / original_size[0]
        y_slice, x_slice = crop_slice
        scaled_crop_slice = (slice(int(y_slice.start * scale), int(y_slice.stop * scale)),
                             slice(int(x_slice.start * scale), int(x_slice.stop * scale)))

        material_dir = os.path.join(MATERIAL_ROOT_DIR, mat_category, mat_name)
        sample_file = f'params_{mat_param_idx}.png'

        parsed_material_dict[key] = {
            'original_size': original_size,
            'crop_slice': crop_slice,
            'scaled_crop_slice': scaled_crop_slice,
            'material_category': mat_category,
            'material_name': mat_name,
            'material_param_idx': mat_param_idx,
            # Pre-sampled texture maps; note the albedo lives in the "basecolor" folder.
            'material_path': {
                'albedo': os.path.join(material_dir, "sampled", "basecolor", sample_file),
                'metallic': os.path.join(material_dir, "sampled", "metallic", sample_file),
                'roughness': os.path.join(material_dir, "sampled", "roughness", sample_file),
                'normal': os.path.join(material_dir, "sampled", "normal", sample_file),
            },
            # Inputs needed to instantiate the procedural material graph.
            'procedural': {
                'sbs_file_path': os.path.join(material_dir, f"{mat_name}.sbs"),
                'mgt_res': 9,
                'external_input_path': os.path.join(material_dir, "external_input"),
                'ckp_path': os.path.join(material_dir, "sampled",
                                         "param", f"params_{mat_param_idx}.pth"),
            },
        }

    return parsed_material_dict


def load_all_masks(mask_image_path, obj_idx_list, train_resolution):
    """Load the overall mask and one segmentation mask per object onto the GPU.

    The per-object mask for index ``i`` is read from the overall mask path with
    ``_i`` inserted before the file extension (``mask.png`` -> ``mask_3.png``).
    The extension is split off with ``os.path.splitext`` so that only the file
    suffix is touched, whatever the directory names or extension are.

    Args:
        mask_image_path: Path of the overall mask image.
        obj_idx_list: Object indices whose segmentation masks should be loaded.
        train_resolution: Size passed to ``PIL.Image.resize``.

    Returns:
        ``(mask_image, seg_mask_dict)``. Each mask is a float32 CUDA tensor
        scaled to ``[0, 1]`` with a trailing singleton axis; ``seg_mask_dict``
        is keyed by ``str(obj_idx)``.
    """
    def _load_mask(path):
        # Grayscale -> resize -> [0, 1] float tensor with a trailing channel axis.
        with Image.open(path) as img:
            gray = img.convert('L').resize(train_resolution)
        values = np.array(gray).astype(np.float32) / 255.0
        return torch.tensor(values, dtype=torch.float32, device='cuda').unsqueeze(-1)

    mask_image = _load_mask(mask_image_path)

    stem, extension = os.path.splitext(mask_image_path)
    seg_mask_dict = {
        f'{obj_idx}': _load_mask(f"{stem}_{obj_idx}{extension}")
        for obj_idx in obj_idx_list
    }

    return mask_image, seg_mask_dict


def color_loss_masked(image, ref, mask, down_scale=False):
    """Return the L1 distance between ``image`` and ``ref`` after masking both.

    Args:
        image: Rendered image tensor.
        ref: Reference image tensor.
        mask: Mask multiplied into both images before the comparison.
        down_scale: When true, the three 3-D inputs are first given a leading
            batch axis, have their last axis moved to position 1, and are
            bilinearly resampled by a factor of 0.125.

    Returns:
        Scalar tensor produced by ``F.l1_loss``.
    """
    if down_scale:
        def _shrink(tensor):
            batched = tensor.unsqueeze(0).permute(0, 3, 1, 2)
            return F.interpolate(batched, scale_factor=0.125, mode='bilinear')

        image, ref, mask = _shrink(image), _shrink(ref), _shrink(mask)

    return F.l1_loss(image * mask, ref * mask)


def texture_loss_fn(image, ref, crop_slice, texture_descriptor):
    """Return the L1 distance between texture descriptors of two image crops.

    Args:
        image: Rendered image tensor, indexed with ``crop_slice``.
        ref: Reference image tensor, indexed with the same ``crop_slice``.
        crop_slice: Index applied to both images to select the region.
        texture_descriptor: Object whose ``evaluate`` method maps a resized
            crop to a feature tensor.

    Returns:
        Scalar tensor produced by ``F.l1_loss`` on the two descriptor outputs.
    """
    def _resized_crop(tensor):
        # Crop, add a batch axis, move the last axis to position 1, resample to 512x512.
        patch = tensor[crop_slice].unsqueeze(0).permute(0, 3, 1, 2)
        return F.interpolate(patch, size=(512, 512), mode='bilinear')

    rendered_patch = _resized_crop(image)
    reference_patch = _resized_crop(ref)

    rendered_features = texture_descriptor.evaluate(rendered_patch)
    reference_features = texture_descriptor.evaluate(reference_patch)
    return F.l1_loss(rendered_features, reference_features)


def acquire_material_parameters(opt_materials):
    """Evaluate every optimizable material and collect its maps by object index.

    Args:
        opt_materials: Mapping from object index to a ``ProceduralMaterial``
            or ``UniformMaterial``.

    Returns:
        dict mapping each object index to a dict with a ``type`` tag
        (``'procedural'`` or ``'uniform'``) followed by the maps returned by
        ``evaluate_maps()``: ``normal`` (procedural only), ``albedo``,
        ``metallic`` and ``roughness``.

    Raises:
        ValueError: If a material is neither of the two supported classes.
    """
    def _describe(mat):
        if isinstance(mat, ProceduralMaterial):
            kind, channels = 'procedural', ('normal', 'albedo', 'metallic', 'roughness')
        elif isinstance(mat, UniformMaterial):
            kind, channels = 'uniform', ('albedo', 'metallic', 'roughness')
        else:
            raise ValueError("Unknown material type")

        # Evaluated outside torch.no_grad here, so the maps keep whatever graph
        # evaluate_maps() builds.
        maps = mat.evaluate_maps()
        entry = {'type': kind}
        entry.update((channel, maps[channel]) for channel in channels)
        return entry

    return {obj_idx: _describe(mat) for obj_idx, mat in opt_materials.items()}


@torch.no_grad()
def save_materials(opt_materials, material_eval_dir, cur_step):
    """Write every material's texture maps as PNG files and report their paths.

    Files go to ``<material_eval_dir>/obj_<idx>/<channel><suffix>.png`` where
    ``suffix`` is ``_<cur_step:04d>`` for non-negative steps and empty otherwise.

    Args:
        opt_materials: Mapping from object index to a ``ProceduralMaterial``
            or ``UniformMaterial``.
        material_eval_dir: Root directory for the exported maps.
        cur_step: Integer step used in the file names; a negative value
            produces names without a step suffix.

    Returns:
        dict keyed by ``str(obj_idx)``; each value holds ``base_dir`` plus the
        paths (relative to ``base_dir``) of ``albedo``, ``normal``,
        ``metallic`` and ``roughness``. ``normal`` is ``None`` for uniform
        materials, which export no normal map.

    Raises:
        ValueError: If a material is neither of the two supported classes.
    """
    suffix = f"_{cur_step:04d}" if cur_step >= 0 else ""

    def _to_uint8(tensor):
        return np.clip(tensor.detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)

    def _write_png(pixels, relative_name):
        Image.fromarray(pixels).save(os.path.join(material_eval_dir, relative_name))

    exported = {}
    for obj_idx, mat in opt_materials.items():
        folder = f"obj_{obj_idx}"
        names = {channel: f"{folder}/{channel}{suffix}.png"
                 for channel in ('normal', 'albedo', 'metallic', 'roughness')}
        os.makedirs(os.path.join(material_eval_dir, folder), exist_ok=True)

        if isinstance(mat, ProceduralMaterial):
            # Channel layout of the stacked maps: 0-2 albedo, 3-5 normal,
            # 6 roughness, 7 metallic.
            stacked = mat.get_pbr_materials()
            layout = (('albedo', slice(0, 3)), ('normal', slice(3, 6)),
                      ('roughness', slice(6, 7)), ('metallic', slice(7, 8)))
            pixels = {channel: _to_uint8(stacked[0, span, :, :].permute(1, 2, 0))
                      for channel, span in layout}

            _write_png(pixels['normal'], names['normal'])
            _write_png(pixels['albedo'], names['albedo'])
            # Single-channel maps are saved as 2-D grayscale images.
            _write_png(pixels['metallic'][..., 0], names['metallic'])
            _write_png(pixels['roughness'][..., 0], names['roughness'])
            normal_entry = names['normal']
        elif isinstance(mat, UniformMaterial):
            # Uniform values are broadcast into constant 512x512 images.
            values = mat.evaluate_maps()
            albedo_pixels = np.ones((512, 512, 3), dtype=np.uint8) * _to_uint8(values['albedo'])
            metallic_pixels = np.ones((512, 512), dtype=np.uint8) * _to_uint8(values['metallic'])
            roughness_pixels = np.ones((512, 512), dtype=np.uint8) * _to_uint8(values['roughness'])

            _write_png(albedo_pixels, names['albedo'])
            _write_png(metallic_pixels, names['metallic'])
            _write_png(roughness_pixels, names['roughness'])
            normal_entry = None
        else:
            raise ValueError("Unknown material type")

        exported[str(obj_idx)] = {
            'base_dir': material_eval_dir,
            'albedo': names['albedo'],
            'normal': normal_entry,
            'metallic': names['metallic'],
            'roughness': names['roughness'],
        }

    return exported


def validate_iter(scene, params, opt_materials, ref_image, cur_step, output_dir, material_eval_dir=None, spp=256):
    """Render the current scene and save a reference-vs-render comparison image.

    The render is converted with ``linear_to_srgb``, placed to the right of
    ``ref_image`` and written to ``<output_dir>/compare_step_<cur_step>.png``.

    Args:
        scene: Scene passed to ``mi.render``.
        params: Scene parameters passed to ``mi.render``.
        opt_materials: Materials forwarded to ``save_materials``.
        ref_image: Reference image tensor with values scaled by 255 on export.
        cur_step: Step number used in the output file names.
        output_dir: Directory receiving the comparison image.
        material_eval_dir: When not ``None``, the current material maps are
            also exported there.
        spp: Samples per pixel for the render.
    """
    def _as_uint8(tensor):
        return np.clip(tensor.detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)

    render_srgb = linear_to_srgb(mi.render(scene, params=params, spp=spp).torch())

    # Reference on the left, render on the right; OpenCV expects BGR channel order.
    side_by_side = np.concatenate([_as_uint8(ref_image), _as_uint8(render_srgb)], axis=1)
    side_by_side = cv2.cvtColor(side_by_side, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(output_dir, f"compare_step_{cur_step}.png"), side_by_side)

    if material_eval_dir is not None:
        save_materials(opt_materials, material_eval_dir, cur_step)


def merge_objs(params, opt_materials, save_last_dir, merged_filename="mesh.obj"):
    """Merge all objects into one OBJ file (one group per object) plus an MTL file.

    The texture maps are exported first via ``save_materials``; the MTL file
    then references them by their paths relative to ``save_last_dir``.

    Args:
        params: Scene parameters; for every object index ``i`` the keys
            ``obj_i.vertex_positions``, ``obj_i.vertex_normals``,
            ``obj_i.vertex_texcoords`` and ``obj_i.faces`` are read.
        opt_materials: Mapping from object index to its material.
        save_last_dir: Output directory for the OBJ, the MTL and the textures.
        merged_filename: File name of the merged OBJ. The MTL file uses the
            same stem with a ``.mtl`` extension.

    Returns:
        Path of the merged OBJ file.
    """
    # Export the texture maps first so the MTL entries have something to point at.
    material_path_dict = save_materials(opt_materials, save_last_dir, -1)

    merged_obj_path = os.path.join(save_last_dir, merged_filename)
    # Swap only the extension: a substring replace would leave a name without
    # ".obj" unchanged, and the MTL would then overwrite the OBJ file.
    merged_mtl_path = os.path.join(save_last_dir, os.path.splitext(merged_filename)[0] + '.mtl')

    # OBJ indices are global across groups, so each object's faces are shifted
    # by the number of vertices written before it.
    vertex_offset = 0
    all_materials = {}

    with open(merged_obj_path, 'w') as merged_obj:
        merged_obj.write(f"mtllib {os.path.basename(merged_mtl_path)}\n\n")

        for obj_idx, mat in opt_materials.items():
            obj_key = f'obj_{obj_idx}'

            vertices = params[f"{obj_key}.vertex_positions"].torch().reshape(-1, 3).cpu().numpy()
            normals = params[f"{obj_key}.vertex_normals"].torch().reshape(-1, 3).cpu().numpy()
            uv = params[f"{obj_key}.vertex_texcoords"].torch().reshape(-1, 2).cpu().numpy()
            faces = params[f"{obj_key}.faces"].torch().reshape(-1, 3).cpu().numpy()

            material_paths = material_path_dict[str(obj_idx)]

            if isinstance(mat, ProceduralMaterial):
                uv_scale, uv_rotation = mat.get_scale().item(), mat.get_rotate().item()
            else:
                uv_scale, uv_rotation = 1.0, 0.0

            # The rotation is converted from radians to degrees for the MTL file.
            uv_rotation = np.rad2deg(uv_rotation)

            material_name = f"mat_{obj_idx}"
            all_materials[material_name] = {
                "uv_scale": uv_scale,
                "uv_rotation": uv_rotation,
                "albedo_path": material_paths['albedo'],
                "metallic_path": material_paths['metallic'],
                "roughness_path": material_paths['roughness'],
                "normal_path": material_paths['normal'],
            }

            merged_obj.write(f"g obj_{obj_idx}\n")

            merged_obj.writelines(f"v {v[0]} {v[1]} {v[2]}\n" for v in vertices)
            merged_obj.writelines(f"vt {vt[0]} {vt[1]}\n" for vt in uv)
            merged_obj.writelines(f"vn {vn[0]} {vn[1]} {vn[2]}\n" for vn in normals)

            merged_obj.write(f"\nusemtl {material_name}\n\n")

            # Faces are 1-based; position, texcoord and normal share one index
            # because the three arrays are written per vertex in the same order.
            for face in faces:
                f1 = face[0] + 1 + vertex_offset
                f2 = face[1] + 1 + vertex_offset
                f3 = face[2] + 1 + vertex_offset
                merged_obj.write(f"f {f1}/{f1}/{f1} {f2}/{f2}/{f2} {f3}/{f3}/{f3}\n")

            vertex_offset += len(vertices)

            merged_obj.write("\n")

    # When true, every texture statement carries a "-s u v w" scale option.
    write_scale = True
    texture_statements = (
        ("map_Kd", "albedo_path"),
        ("map_Pm", "metallic_path"),
        ("map_Pr", "roughness_path"),
        ("map_bump", "normal_path"),
    )

    with open(merged_mtl_path, 'w') as merged_mtl:
        for material_name, mat_info in all_materials.items():
            merged_mtl.write(f"newmtl {material_name}\n")

            # Non-standard statements recording the optimized UV transform.
            merged_mtl.write(f"uv_scale {mat_info['uv_scale']}\n")
            merged_mtl.write(f"uv_rotation {mat_info['uv_rotation']}\n")

            if write_scale:
                scale_option = f"-s {mat_info['uv_scale']} {mat_info['uv_scale']} 1 "
            else:
                scale_option = ""

            for statement, path_key in texture_statements:
                texture_path = mat_info[path_key]
                # Maps that were not exported (e.g. no normal map) are skipped.
                if texture_path:
                    merged_mtl.write(f"{statement} {scale_option}{texture_path}\n")

            merged_mtl.write("\n")

    print(f"Successfully merged {len(opt_materials)} objects into {merged_obj_path}")
    return merged_obj_path


def validate_iter_drjit(mts_scene, opt_params, save_path, return_img=False):
    """Render the scene at 512 spp, convert it to sRGB and save it to ``save_path``.

    Args:
        mts_scene: Scene passed to ``render_scene_drjit``.
        opt_params: Parameters passed to ``render_scene_drjit``.
        save_path: Destination handed to ``torchvision.utils.save_image``.
        return_img: When true, the sRGB render is returned.

    Returns:
        The sRGB render if ``return_img`` is true, otherwise ``None``.
    """
    srgb_render = linear_to_srgb_drjit(render_scene_drjit(mts_scene, opt_params, spp=512))

    # save_image receives the render with its last axis moved to the front
    # and a leading batch axis added.
    batch = srgb_render.torch().permute(2, 0, 1).unsqueeze(0)
    save_image(batch, save_path)

    return srgb_render if return_img else None


def inverse_rendering_func(mesh_dir, camera_dict, material_pair_dict, output_dir, seamless_texture_paths,
                      seamless_texture_idx,
                      mesh_translation, mesh_scale, mesh_rotation, mask_image_path, target_image_path,
                      train_resolution=(512, 512), fw_spp=256, bw_spp=64, learning_rate=0.01, max_steps=500):
    """Jointly optimize object materials and an environment map against a target image.

    Steps:
        1. Build the scene from the OBJ files found in ``mesh_dir``.
        2. For every ``ProceduralMaterial``, grid-search the UV scale/rotation
           that minimizes the texture loss on the material's crop region.
        3. Run ``max_steps`` Adam iterations on the material parameters and a
           256x512x3 environment map, using a masked color loss plus a
           per-object material loss.
        4. Write the loss log, the merged mesh with its textures and the final
           environment map.

    Files written under ``output_dir``: ``material_eval/``,
    ``crop_<idx>_scale_<s>.png``, ``compare_step_<n>.png``,
    ``envmap_step_<n>.exr``, ``loss_record.txt``, ``mesh_materials/`` and
    ``final_envmap.exr``.

    Args:
        mesh_dir: Directory passed to ``get_obj_files_list``.
        camera_dict: Camera description forwarded to ``create_scene_dict``.
        material_pair_dict: Raw material assignment, parsed by
            ``parse_material_pair_dict``.
        output_dir: Root directory for every output.
        seamless_texture_paths: When truthy and non-empty, seamless-texture
            mode is used (initial UV scale 2 and a wider search grid); also
            forwarded to ``create_scene_dict``.
        seamless_texture_idx: Forwarded to ``create_scene_dict``.
        mesh_translation: Forwarded as ``translation``.
        mesh_scale: Forwarded as ``scale``.
        mesh_rotation: Forwarded as ``rotation``.
        mask_image_path: Path handed to ``load_all_masks``.
        target_image_path: Target image, loaded at ``train_resolution``.
        train_resolution: Resolution used for the target, masks and scene.
        fw_spp: Samples per pixel for the validation renders.
        bw_spp: Samples per pixel for the search and optimization renders.
        learning_rate: Learning rate of both Adam optimizers.
        max_steps: Number of optimization iterations.
    """
    # Collect the OBJ files; each object's index is the first run of digits in its file name.
    obj_files = get_obj_files_list(mesh_dir)
    obj_idx_list = [int(re.search(r'\d+', os.path.basename(obj_file)).group(0)) for obj_file in obj_files]

    parsed_material_dict = parse_material_pair_dict(material_pair_dict, train_resolution)
    material_eval_dir = os.path.join(output_dir, "material_eval")
    os.makedirs(material_eval_dir, exist_ok=True)
    init_hdr_path = f"{HDR_ROOT_DIR}/uniform.exr"

    use_seamless_textures = seamless_texture_paths and len(seamless_texture_paths) > 0
    if use_seamless_textures:
        print("Using seamless textures for material optimization")
        uv_scale = [2.0, 2.0]
    else:
        print("Running optimization without seamless textures")
        uv_scale = [1.0, 1.0]  # keep the original scale so the texture stays in place

    scene_dict, opt_param_keys, opt_materials = create_scene_dict(
        obj_files, obj_idx_list, camera_dict, parsed_material_dict, init_hdr_path, train_resolution,
        uv_scale=uv_scale, uv_rotation=0.0,
        translation=mesh_translation, scale=mesh_scale, rotation=mesh_rotation,
        seamless_texture_paths=seamless_texture_paths, seamless_texture_idx=seamless_texture_idx)
    scene = mi.load_dict(scene_dict)
    sc_params = mi.traverse(scene)

    mask_image, seg_mask_dict = load_all_masks(mask_image_path, obj_idx_list, train_resolution)

    # Gather the trainable parameters of every material.
    torch_params_list = []
    for value in opt_materials.values():
        if value is not None:
            torch_params_list.extend(list(value.get_parameters()))

    # The target is compared in linear space; the environment map starts as
    # uniform white and has its own optimizer.
    target_image = load_image_as_tensor(target_image_path, train_resolution)
    target_image = target_image.squeeze(0).permute(1, 2, 0).cuda()
    opt_envmap = torch.ones(256, 512, 3, device='cuda', dtype=torch.float32, requires_grad=True)
    optimizer_envmap = torch.optim.Adam([opt_envmap], lr=learning_rate)
    ref_image = srgb_to_linear(target_image)
    # sRGB version of the reference; constant, so computed once.
    ref_image_srgb = linear_to_srgb(ref_image)

    optimizer = torch.optim.Adam(torch_params_list, lr=learning_rate)

    # Feature extractor behind the texture loss.
    texture_descriptor = TextureDescriptor(device='cuda')

    if use_seamless_textures:
        # Seamless textures tile freely, so a wide scale/rotation grid is searched.
        uv_scale_list = [2.0, 4.0, 6.0, 8.0]
        uv_rotation_list = [-90, -45, 0.0, 45.0, 90.0]
        print("Start to find best initial UV scale with seamless textures...")
    else:
        # Without seamless textures a more conservative grid keeps the texture in place.
        uv_scale_list = [0.5, 1.0, 1.5, 2.0]
        uv_rotation_list = [-45, 0.0, 45.0]
        print("Start to find best initial UV scale without seamless textures...")
    uv_rotation_rad_list = [np.deg2rad(rot) for rot in uv_rotation_list]

    for obj_idx, mat in opt_materials.items():
        if not isinstance(mat, ProceduralMaterial):
            continue

        crop_slice = parsed_material_dict[f'{obj_idx}']['scaled_crop_slice']
        best_scale = 1.0
        best_rotation_rad = 0.0
        best_loss = float('inf')
        with tqdm(total=len(uv_scale_list) * len(uv_rotation_rad_list)) as search_bar:
            for scale in uv_scale_list:
                for uv_rotation_rad in uv_rotation_rad_list:
                    mat.set_scale(scale)
                    mat.set_rotate(uv_rotation_rad)
                    with torch.no_grad():
                        material_dict = acquire_material_parameters(opt_materials)
                        rendered_image = render_scene(scene, sc_params, material_dict, opt_envmap, spp=bw_spp)
                        rendered_image = linear_to_srgb(rendered_image)
                        loss = texture_loss_fn(rendered_image, ref_image_srgb, crop_slice,
                                               texture_descriptor).item()
                    if loss < best_loss:
                        best_loss = loss
                        best_scale = scale
                        best_rotation_rad = uv_rotation_rad
                    search_bar.update(1)

        mat.set_scale(best_scale)
        mat.set_rotate(best_rotation_rad)
        print(f"obj_{obj_idx} best scale: {best_scale}, best rotation: {np.rad2deg(best_rotation_rad)}")

        # Save the reference crop next to the render obtained with the best UV transform.
        with torch.no_grad():
            material_dict = acquire_material_parameters(opt_materials)
            rendered_image = render_scene(scene, sc_params, material_dict, opt_envmap, spp=bw_spp)
            rendered_image = linear_to_srgb(rendered_image)
            cropped_rendered_image = (rendered_image[crop_slice].detach().cpu().numpy() * 255).astype(np.uint8)
            cropped_ref_image = (ref_image_srgb[crop_slice].detach().cpu().numpy() * 255).astype(np.uint8)
            compare_crop = np.concatenate([cropped_ref_image, cropped_rendered_image], axis=1)
            compare_crop = cv2.cvtColor(compare_crop, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(output_dir, f"crop_{obj_idx}_scale_{best_scale}.png"), compare_crop)

    print("Start to optimize...")

    pbar = tqdm(total=max_steps)
    loss_record = []
    cur_step = 0
    while cur_step < max_steps:
        torch.cuda.empty_cache()
        # Both optimizers must be reset; otherwise the environment-map gradient
        # accumulates over all previous steps.
        optimizer.zero_grad()
        optimizer_envmap.zero_grad()

        # Radiance cannot be negative, so the render sees the clamped map.
        emitter_data = opt_envmap.clamp(0.0)

        material_dict = acquire_material_parameters(opt_materials)

        rendered_image = render_scene(scene, sc_params, material_dict, emitter_data, spp=bw_spp)
        rendered_image = clean_tensor_for_gamma(rendered_image)

        # Global color term on the downscaled, masked image.
        color_loss = color_loss_masked(rendered_image, ref_image, mask_image, True)

        # Per-object term: texture loss on the crop for procedural materials,
        # masked color loss on the object's segmentation mask otherwise.
        material_loss = 0.0
        for key, value in parsed_material_dict.items():
            if value['procedural'] is not None:
                material_loss += texture_loss_fn(rendered_image, ref_image,
                                                 value['scaled_crop_slice'], texture_descriptor)
            else:
                material_loss += color_loss_masked(rendered_image, ref_image, seg_mask_dict[key], False)
        loss = color_loss + material_loss
        loss.backward()

        # float() accepts both a scalar tensor and the plain 0.0 left over when
        # there are no material entries.
        total_loss = float(loss)
        total_material_loss = float(material_loss)
        total_color_loss = float(color_loss)
        total_envmap_loss = 0.0
        loss_record.append((cur_step, total_loss, total_material_loss, total_color_loss, total_envmap_loss))

        optimizer_envmap.step()
        optimizer.step()

        if cur_step % 10 == 0:
            tqdm.write(f"**Step {cur_step}, "
                       f"Learning rate: {learning_rate:.6f}, "
                       f"Loss: {total_loss:.6f}, "
                       f"Material Loss: {total_material_loss:.6f}, "
                       f"Color Loss: {total_color_loss:.6f}, "
                       f"Envmap Loss: {total_envmap_loss:.6f}.")

            # Every 20 steps: comparison render, material maps and the current envmap.
            if cur_step % 20 == 0:
                validate_iter(scene, sc_params, opt_materials, ref_image_srgb, cur_step,
                              output_dir, material_eval_dir, spp=fw_spp)
                # NOTE(review): the map is written as-is (unclamped, no RGB->BGR
                # swap before cv2.imwrite) -- confirm this matches the reader.
                envmap = opt_envmap.detach().cpu().numpy()
                envmap_path = os.path.join(output_dir, f"envmap_step_{cur_step}.exr")
                cv2.imwrite(envmap_path, envmap)

        pbar.update(1)
        cur_step += 1
    pbar.close()

    # Loss log, six decimal places per value.
    with open(os.path.join(output_dir, "loss_record.txt"), 'w') as f:
        for record in loss_record:
            f.write(f"Step {record[0]}: "
                    f"Learning rate: {learning_rate:.6f},"
                    f"Total loss: {record[1]:.6f}, "
                    f"Material loss: {record[2]:.6f}, "
                    f"Color loss: {record[3]:.6f}\n")

    # Final mesh with its textures.
    save_last_dir = os.path.join(output_dir, "mesh_materials")
    os.makedirs(save_last_dir, exist_ok=True)
    merge_objs(sc_params, opt_materials, save_last_dir)

    # Final environment map.
    envmap = opt_envmap.detach().cpu().numpy()
    envmap_path = os.path.join(output_dir, "final_envmap.exr")
    cv2.imwrite(envmap_path, envmap)


class MaterialOptimizer:
    def __init__(self,
                 model_name,
                 model_index,
                 gt_idx,
                 resolution=512,
                 total_iter=200,
                 eval_step=50,
                 loss_type="mae",
                 use_tv_loss=False,
                 tv_loss_weight=0.01,
                 seed=1,
                 device="cuda"):
        """
        Store the run configuration and resolve all input/output paths.

        Creates the output directory, seeds the random number generators,
        locates the ground-truth image and builds the loss function and an
        AOV integrator.

        Args:
            model_name: Name of the model; used to build mesh/input/output paths
            model_index: Index of the model; used to build mesh/input/output paths
            gt_idx: Index (file stem) of the ground truth image
            resolution: Resolution for rendering and textures
            total_iter: Total number of optimization iterations
            eval_step: Steps between evaluations
            loss_type: Name passed to ``get_loss_fn_drjit``
            use_tv_loss: Whether to use total variation loss
            tv_loss_weight: Weight for TV loss
            seed: Random seed
            device: Device to run on

        Raises:
            FileNotFoundError: If no ground truth image with a supported
                extension exists in the input directory.
        """
        # Which asset / target image this run works on.
        self.model_name, self.model_index, self.gt_idx = model_name, model_index, gt_idx

        # Rendering resolution and iteration schedule.
        self.res = resolution
        self.total_iter, self.eval_step = total_iter, eval_step

        # Loss configuration.
        self.loss_type = loss_type
        self.use_tv_loss, self.tv_loss_weight = use_tv_loss, tv_loss_weight

        self.seed = seed
        self.device = device

        # One output directory per (model, index, target image) triple.
        self.output_dir = f"exps/{model_name}-{model_index}-{gt_idx}"
        os.makedirs(self.output_dir, exist_ok=True)

        self._set_seed(seed)

        # Input locations.
        self.obj_file_path = f"{MESH_ROOT_DIR}/{model_name}/{model_index}"
        self.input_dir = f"{TRANSFER_DATA_DIR}/{model_name}-{model_index}"
        self.mask_image_path = f"{self.input_dir}/mask.png"
        self.view_dict_path = f"{self.input_dir}/best_view.json"

        # The target image may be stored under several extensions; probe for it.
        self.gt_image_path = self._find_image_path(gt_idx)

        self.loss_fn = get_loss_fn_drjit(loss_type)

        aov_config = {
            'type': 'aov',
            'aovs': 'albedo:albedo, normals:sh_normal, dd.y:depth'
        }
        self.aov_integrator = mi.load_dict(aov_config)

    def _find_image_path(self, gt_idx):
        """
        自动检测图片格式，支持jpg和png格式
        
        Args:
            gt_idx: 图片索引
            
        Returns:
            str: 图片文件路径
            
        Raises:
            FileNotFoundError: 如果找不到对应格式的图片文件
        """
        # 支持的图片格式列表（按优先级排序）
        image_extensions = ['.jpg', '.jpeg', '.png']
        
        for ext in image_extensions:
            image_path = f"{self.input_dir}/{gt_idx}{ext}"
            if os.path.exists(image_path):
                print(f"Found image: {image_path}")
                return image_path
        
        # 如果都找不到，抛出异常
        raise FileNotFoundError(f"No image found for gt_idx={gt_idx} in {self.input_dir}. "
                              f"Tried extensions: {image_extensions}")

    def _set_seed(self, seed):
        """
        Set random seeds for reproducibility.

        Args:
            seed: Seed value
        """
        torch.manual_seed(seed)
        np.random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    def _load_target_image(self, tensorxf=False):
        """
        Load the ground-truth image, align it to the mask and save a copy.

        Both the image and the mask are resized to ``self.res`` x ``self.res``.
        The mask is binarised at 0.5 before being handed to ``align_to_mask``.
        The aligned image is written to ``gt_image.png`` in the output
        directory.

        Args:
            tensorxf: When true, wrap the result in ``mi.TensorXf``

        Returns:
            The aligned target image on ``self.device``: a torch tensor, or an
            ``mi.TensorXf`` when ``tensorxf`` is true.

        Raises:
            FileNotFoundError: If ``self.gt_image_path`` does not exist
        """
        if not os.path.exists(self.gt_image_path):
            raise FileNotFoundError(f"Ground truth image not found: {self.gt_image_path}")

        target_size = (self.res, self.res)

        # Ground truth: RGB, resized, scaled from 0..255 to 0..1.
        gt_pil = Image.open(self.gt_image_path).convert("RGB").resize(target_size)
        image_gt = torch.from_numpy(np.array(gt_pil, dtype=np.float32) / 255.0)

        # Mask: single channel, resized, then thresholded into {0, 1}.
        mask_pil = Image.open(self.mask_image_path).convert("L").resize(target_size)
        mask_np = np.array(mask_pil, dtype=np.float32) / 255.0
        mask_np = (mask_np > 0.5).astype(np.float32)
        # Two leading singleton axes are added in front of the 2-D mask.
        mask_image = torch.from_numpy(mask_np)[None, None]

        # The image is passed channel-first with a leading batch axis.
        gt_batched = image_gt.permute(2, 0, 1).unsqueeze(0)
        aligned_image = align_to_mask(gt_batched, self.gt_image_path, mask_image)
        # assumes align_to_mask returns the same batched channel-first layout — TODO confirm
        image_gt = aligned_image.squeeze(0).permute(1, 2, 0).contiguous().to(self.device)

        # Keep a copy of the aligned target next to the other outputs.
        save_image(image_gt.permute(2, 0, 1).unsqueeze(0), os.path.join(self.output_dir, "gt_image.png"))

        return mi.TensorXf(image_gt) if tensorxf else image_gt

    def _setup_scene(self):
        """
        Set up the Mitsuba scene.

        Reads the camera description from ``self.view_dict_path`` (JSON),
        stores it on ``self.camera_dict`` and builds the scene through
        ``create_scene_drjit`` with a fixed mesh transform.

        Returns:
            tuple: ``(mts_scene, opt_params, opt_param_keys, obj_files,
            obj_idx_list)`` where ``opt_params`` is the result of
            ``mi.traverse(mts_scene)`` and the remaining values come from
            ``create_scene_drjit``.
        """
        # Mitsuba config
        mesh_translation = [0, 0, 0]
        mesh_scale = 1.0
        mesh_rotation = [90, 90, 0]

        # Use a context manager so the file handle is closed right away
        # instead of being left for the garbage collector.
        with open(self.view_dict_path) as view_file:
            camera_dict = json.load(view_file)
        self.camera_dict = camera_dict

        mts_scene, opt_param_keys, obj_files, obj_idx_list = create_scene_drjit(
            self.obj_file_path, mesh_translation, mesh_scale, mesh_rotation,
            camera_dict, resolution=(self.res, self.res))

        opt_params = mi.traverse(mts_scene)

        return mts_scene, opt_params, opt_param_keys, obj_files, obj_idx_list

    def _create_sensor(self, camera_position, camera_target, camera_up):
        """
        Build an orthographic sensor from a look-at description.

        The film is an ``hdrfilm`` of ``self.res`` x ``self.res`` pixels with a
        gaussian reconstruction filter.

        Args:
            camera_position: Look-at origin
            camera_target: Look-at target
            camera_up: Look-at up vector

        Returns:
            The object produced by ``mi.load_dict`` for the sensor description
        """
        view = T().look_at(
            origin=ScalarPoint3f(camera_position),
            target=ScalarPoint3f(camera_target),
            up=ScalarPoint3f(camera_up)
        )
        # Unit scale factors on all three axes, composed after the look-at.
        unit_scale = T().scale(ScalarPoint3f([1, 1, 1]))

        film = {
            'type': 'hdrfilm',
            'width': self.res,
            'height': self.res,
            'filter': {'type': 'gaussian'},
        }

        sensor_dict = {
            'type': 'orthographic',
            'near_clip': 0.01,
            'far_clip': 100,
            'to_world': view @ unit_scale,
            'film': film,
        }
        return mi.load_dict(sensor_dict)

    def optimize_drjit(self):
        """
        Run the material optimization process.

        Loads the target image, builds the scene and registers every key in
        ``opt_param_keys`` with an Adam optimizer (lr=0.01). It then runs
        ``self.total_iter`` iterations of render -> loss -> backward -> step,
        clamping the optimized values to [0, 1] after each step.
        ``validate_iter_drjit`` is invoked every ``self.eval_step`` iterations
        and ``_save_final_results_drjit`` once the loop has finished.
        """
        # Load target image (as mi.TensorXf so it can be compared with the render)
        image_gt = self._load_target_image(True)

        # Setup scene; this also populates self.camera_dict.
        # obj_files is returned but not used in this method.
        mts_scene, opt_params, opt_param_keys, obj_files, obj_idx_list = self._setup_scene()

        # Camera jitter is hard-wired on; there is no argument to disable it.
        use_camera_jitter = True
        camera_position = self.camera_dict['camera_position']
        camera_target = self.camera_dict['camera_target']
        camera_up = self.camera_dict['camera_up']

        # Setup optimizer
        optimizer = mi.ad.Adam(lr=0.01)

        # Hand every optimizable parameter to the optimizer, then write the
        # optimizer-owned values back into the scene parameters.
        for key in opt_param_keys:
            optimizer[key] = opt_params[key]
        opt_params.update(optimizer)


        # NOTE(review): losses is only appended to; nothing in this method reads it.
        losses = []

        # Main optimization loop
        for it in tqdm(range(self.total_iter)):
            # Fresh seeds every iteration: one for `seed`, one for `seed_grad`.
            seed_f = np.random.randint(2 ** 31)
            seed_b = np.random.randint(2 ** 31)

            # On even iterations the camera origin is perturbed by up to
            # +/-0.01 per axis and a new sensor is built for it.
            if use_camera_jitter and it % 2 == 0:
                # NOTE(review): camera_position presumably is a plain list parsed
                # from JSON; the sum then relies on NumPy to add element-wise — confirm.
                new_campos = camera_position + np.random.uniform(-0.01, 0.01, size=3)
                sensor = self._create_sensor(new_campos, camera_target, camera_up)
            else:
                # presumably an index selecting the scene's own sensor —
                # confirm against render_scene_drjit
                sensor = 0

            # Render with two different random seeds for unbiased loss
            diff_image = render_scene_drjit(mts_scene, opt_params, sensor=sensor,
                                            spp=64, spp_grad=64, seed=seed_f, seed_grad=seed_b)
            # The loss is evaluated on the sRGB-converted render.
            diff_image = linear_to_srgb_drjit(diff_image)

            loss = self.loss_fn(diff_image, image_gt)

            # Add TV loss (on the rendered image) if enabled
            if self.use_tv_loss:
                loss += self.tv_loss_weight * total_variation_loss_drjit(diff_image)

            dr.backward(loss)

            # Optimization step, then clamp every optimized value to [0, 1]
            # before pushing the result back into the scene parameters.
            optimizer.step()
            for key in opt_param_keys:
                optimizer[key] = dr.clamp(optimizer[key], 0.0, 1.0)
            opt_params.update(optimizer)

            losses.append(loss)

            tqdm.write(f"Iteration {it}, Loss: {loss.torch():.6f}")
            # Evaluation step (also fires on iteration 0)
            if it % self.eval_step == 0:
                tqdm.write(f"Evaluating iteration {it}...")
                validate_iter_drjit(mts_scene, opt_params, os.path.join(self.output_dir, f"iter_{it}.png"))

        # Save final render and textures
        self._save_final_results_drjit(mts_scene, opt_params, obj_idx_list, opt_param_keys)

    def _save_final_results_drjit(self, mts_scene, opt_params, obj_idx_list, opt_param_keys, relighting=False):
        """
        Write the end-of-run artifacts into ``self.output_dir``.

        In order: the final render (via ``validate_iter_drjit``), the masked
        images, the optional relighting renders and the optimized textures.

        Args:
            mts_scene: Mitsuba scene
            opt_params: Scene parameters
            obj_idx_list: List of object indices
            opt_param_keys: Parameter keys
            relighting: When true, also run ``_run_relighting_drjit``
        """
        final_render_path = os.path.join(self.output_dir, "final_render.png")
        final_img = validate_iter_drjit(mts_scene, opt_params, final_render_path, return_img=True)

        print("Saving masked images...")
        save_masked_image(mask_dir=self.input_dir, output_dir=self.output_dir, image=final_img)

        if relighting:
            print("Running relighting...")
            self._run_relighting_drjit(mts_scene, opt_params)

        print("Saving optimized textures...")
        self._save_optimized_textures(opt_params, opt_param_keys, obj_idx_list)

        print(f"All results saved to {self.output_dir}")

    def _run_relighting_drjit(self, mts_scene, opt_params):
        """
        Render the scene under a fixed set of environment maps.

        For each EXR under ``HDR_ROOT_DIR`` the scene is re-rendered through
        ``relighting_scene_drjit`` at 512 spp, passed through
        ``linear_to_srgb_drjit`` and saved as ``relighting_<name>.png`` in the
        output directory.

        Args:
            mts_scene: Mitsuba scene
            opt_params: Scene parameters
        """
        hdr_file_names = ("indoor_01.exr", "indoor_02.exr", "outdoor_01.exr", "uniform.exr")
        relighting_hdr_paths = [f"{HDR_ROOT_DIR}/{file_name}" for file_name in hdr_file_names]

        for hdr_path in relighting_hdr_paths:
            print(f"Relighting with HDR: {hdr_path}")

            emitter_data = mi.TensorXf(mi.Bitmap(hdr_path))
            relit_linear = relighting_scene_drjit(mts_scene, opt_params, emitter_data, spp=512)
            relit_srgb = linear_to_srgb_drjit(relit_linear).torch()

            # File name up to its first dot, e.g. "indoor_01" for "indoor_01.exr".
            hdr_stem = os.path.basename(hdr_path).split('.')[0]
            relighting_output_path = os.path.join(self.output_dir, f"relighting_{hdr_stem}.png")

            # Move the last axis to the front and add a leading batch axis.
            save_image(relit_srgb.permute(2, 0, 1).unsqueeze(0), relighting_output_path)

    def _reg_textures(self, opt_params, opt_param_keys, obj_idx_list):
        # use tv loss to reg textures
        tv_loss = 0.0
        for key in opt_param_keys:
            if "normalmap" in key or "base_color" in key or "roughness" in key or "metallic" in key:
                tv_loss += total_variation_loss_drjit(opt_params[key])

        return tv_loss / len(obj_idx_list)


    def _save_optimized_textures(self, opt_params, opt_param_keys, obj_idx_list):
        """
        Save optimized texture maps.

        Args:
            opt_params: Scene parameters
            opt_param_keys: Parameter keys
            obj_idx_list: List of object indices
        """
        # Create directory for textures
        texture_output_dir = os.path.join(self.output_dir, "optimized_textures")
        os.makedirs(texture_output_dir, exist_ok=True)

        # Texture names
        texture_names = ["albedo", "metallic", "roughness", "normal"]

        print("Saving optimized textures...")
        for obj_idx in obj_idx_list:
            obj_key = f"obj_{obj_idx}"

            for texture_name in texture_names:
                # Build parameter key based on texture type
                if texture_name == "normal":
                    param_key = f"{obj_key}.bsdf.brdf_0.normalmap.data"
                elif texture_name == "albedo":
                    param_key = f"{obj_key}.bsdf.brdf_0.nested_bsdf.base_color.data"
                elif texture_name == "metallic":
                    param_key = f"{obj_key}.bsdf.brdf_0.nested_bsdf.metallic.data"
                elif texture_name == "roughness":
                    param_key = f"{obj_key}.bsdf.brdf_0.nested_bsdf.roughness.data"
                else:
                    param_key = None

                if param_key in opt_param_keys:
                    # Get texture data
                    texture_data = np.array(opt_params[param_key]).clip(0, 1)

                    # Process based on texture type
                    if texture_name == "albedo":
                        # sRGB conversion
                        texture_data = np.power(texture_data, 1 / 2.2)

                    # Save as image
                    texture_path = os.path.join(texture_output_dir, f"{obj_idx}_{texture_name}.png")
                    write_bitmap(texture_path, texture_data)
                    print(f"Saved {texture_name} texture for object {obj_idx}")


