import os
import re
import cv2
import numpy as np
import mitsuba as mi
import os
import subprocess
import json
import matplotlib.pyplot as plt
from pathlib import Path
import glob

# mi.set_variant("scalar_rgb")
mi.set_variant("cuda_ad_rgb")  # 'scalar_rgb', 'scalar_spectral', 'cuda_ad_rgb', 'llvm_ad_rgb'
from mitsuba import ScalarTransform4f as T
from mitsuba import Transform4f as DT
from mitsuba import Point3f, ScalarPoint3f
from mitsuba.util import write_bitmap, convert_to_bitmap
from glob import glob
from PIL import Image
from best_view import find_best_view, create_object_transform, set_seed


def _orthographic_sensor(camera_dict, resolution):
    """Build the orthographic sensor dictionary looking from the stored camera pose."""
    camera_to_world = T().look_at(
        origin=ScalarPoint3f(camera_dict['camera_position']),
        target=ScalarPoint3f(camera_dict['camera_target']),
        up=ScalarPoint3f(camera_dict['camera_up'])
    )
    film = {
        'type': 'hdrfilm',
        'pixel_format': 'rgba',
        'width': resolution[0],
        'height': resolution[1],
        'filter': {'type': 'gaussian'},
    }
    return {
        'type': 'orthographic',
        'near_clip': 0.01,
        'far_clip': 100,
        # Unit scale is a no-op, kept so the transform is built exactly as before.
        'to_world': camera_to_world @ T().scale(ScalarPoint3f([1, 1, 1])),
        'film': film,
    }


def _gray_diffuse_bsdf():
    """Return a fresh two-sided diffuse BSDF dictionary with 0.5 grey reflectance."""
    return {
        'type': 'twosided',
        'nested': {
            'type': 'diffuse',
            'reflectance': {
                'type': 'rgb',
                'value': [0.5, 0.5, 0.5]
            }
        }
    }


def create_scene_dict(obj_files, obj_idx_list, camera_dict, resolution, translation, scale, rotation):
    """Create a Mitsuba scene dictionary containing every object.

    Args:
        obj_files: Paths of the .obj files, paired element-wise with ``obj_idx_list``.
        obj_idx_list: Integer ids; each object is stored under the key ``obj_<idx>``.
        camera_dict: Camera description with 'camera_position', 'camera_target'
            and 'camera_up' entries.
        resolution: (width, height) of the film.
        translation, scale, rotation: Object transform parameters forwarded to
            ``create_object_transform`` (the same transform is used for every object).

    Returns:
        dict for ``mi.load_dict``. It has a constant emitter, an ``aov`` integrator
        producing 'dp:depth,gn:geo_normal', an orthographic sensor, and one grey
        two-sided diffuse ``obj`` shape per file.
    """
    scene_dict = {
        'type': 'scene',
        'emitter': {
            'type': 'constant',
            'radiance': {'type': 'rgb', 'value': 1.0},
        },
        'integrator': {
            'type': 'aov',
            'aovs': 'dp:depth,gn:geo_normal',
            'hide_emitters': True,
        },
        'sensor': _orthographic_sensor(camera_dict, resolution),
    }

    for idx, obj_file in zip(obj_idx_list, obj_files):
        scene_dict[f'obj_{idx}'] = {
            'type': 'obj',
            'filename': obj_file,
            'to_world': create_object_transform(translation, scale, rotation),
            'bsdf': _gray_diffuse_bsdf(),
        }

    return scene_dict

def render_obj_mask(scene, scene_dict, output_dir=None, down_scale=False):
    """Render one binary mask per shape plus a randomly coloured segmentation image.

    An ``aov`` integrator with a ``si:shape_index`` AOV and a box pixel filter is
    built from *copies* of the scene's integrator/sensor dictionaries. The box
    filter keeps shape indices from being blended across pixels, and the caller's
    ``scene_dict`` is left unmodified.

    Args:
        scene: Loaded Mitsuba scene whose shape ids look like ``obj_<idx>``.
        scene_dict: Dictionary the scene was loaded from; must contain
            'integrator' and 'sensor' (with a 'film') entries.
        output_dir: Directory for ``mask_<idx>.png`` and ``segmentation_mask.png``.
            Nothing is written when None.
        down_scale: Also write 1024- and 512-pixel copies of every saved image.
    """
    # Copy the nested dicts we change. Mutating them in place would silently
    # switch the caller's scene_dict (and any later render built from it) to the
    # shape-index AOV and a box filter.
    integrator_dict = dict(scene_dict['integrator'])
    integrator_dict['aovs'] = 'si:shape_index'
    sensor_dict = dict(scene_dict['sensor'])
    sensor_dict['film'] = dict(sensor_dict['film'])
    sensor_dict['film']['filter'] = {'type': 'box'}

    integrator = mi.load_dict(integrator_dict)
    sensor = mi.load_dict(sensor_dict)

    image = mi.render(scene, sensor=sensor, integrator=integrator, spp=1)
    # The AOV value minus one is the shape position (presumably 0 = background).
    # Use a signed integer type so background becomes -1. A uint8 cast would wrap
    # it to 255 (colliding with shape 255) and cap the index at 255.
    shape_index = np.array(image, dtype=np.float32).astype(np.int64) - 1

    obj_mask_list = []
    for shape_pos, shape in enumerate(scene.shapes()):
        # Shape ids are 'obj_<idx>' (see create_scene_dict); recover <idx> for the filename.
        mesh_idx = int(re.search(r'\d+', shape.id()).group(0))
        obj_mask = (shape_index == shape_pos).astype(np.float32)
        obj_mask_list.append(obj_mask)
        if output_dir is not None:
            save_path = os.path.join(output_dir, f"mask_{mesh_idx}.png")
            # Write an explicit 8-bit image instead of relying on cv2's float conversion.
            cv2.imwrite(save_path, (obj_mask * 255).astype(np.uint8))
            if down_scale:
                down_scale_image(save_path, [1024, 1024])
                down_scale_image(save_path, [512, 512])

    # Merge all masks into one segmentation image, one random colour per object.
    # Later objects overwrite earlier ones where masks overlap.
    segmentation = np.zeros((*shape_index.shape[:-1], 3), dtype=np.uint8)
    colormap = np.random.randint(0, 255, size=(len(obj_mask_list), 3), dtype=np.uint8)
    for color, mask in zip(colormap, obj_mask_list):
        segmentation[mask[..., 0] == 1] = color

    if output_dir is not None:
        segmentation_path = os.path.join(output_dir, "segmentation_mask.png")
        cv2.imwrite(segmentation_path, segmentation)
        if down_scale:
            down_scale_image(segmentation_path, [1024, 1024])
            down_scale_image(segmentation_path, [512, 512])

def render_depth_normal_mask(scene, scene_dict, output_dir=None, spp=128, down_scale=False, best_sensor_dict=None):
    """Render depth, geometric-normal and coverage-mask images from sensor 0.

    Assumes the scene's default integrator is the ``aov`` integrator configured
    by ``create_scene_dict`` ('dp:depth,gn:geo_normal'). Under that setup,
    channel 0 of the render is depth and channels 1-3 are the world-space
    geometric normal.

    Args:
        scene: Loaded Mitsuba scene.
        scene_dict: Dictionary the scene was loaded from. Kept for interface
            compatibility; it is not read here.
        output_dir: Directory for mask.png, depth.png, geo_normal.png, and
            normal.png when ``best_sensor_dict`` is given. Nothing is written
            when None.
        spp: Samples per pixel.
        down_scale: Also write 1024- and 512-pixel copies of every saved image.
        best_sensor_dict: Camera description ('camera_position', 'camera_target',
            'camera_up'). When given, a camera-relative normal map is also produced.
    """
    image = mi.render(scene, sensor=0, spp=spp)

    depth = np.array(image[..., 0:1], dtype=np.float32)
    # Zero depth means no surface was hit (background).
    depth_mask = depth != 0
    mesh_mask = depth_mask.astype(np.uint8)

    # Normalise depth to [0, 1] with near = 1 and far = 0.
    # - Guard the empty-hit case (np.min of an empty array raises).
    # - Guard a zero span, which would otherwise divide by zero.
    # - Clip, because depths closer than the clamped minimum (0.1) would exceed
    #   1 and overflow the uint8 cast.
    depth_min = max(float(np.min(depth[depth_mask])), 0.1) if depth_mask.any() else 0.1
    depth_max = max(float(np.max(depth)), 5)
    depth_span = depth_max - depth_min
    if depth_span <= 0:
        depth_span = 1.0
    depth_norm = (depth - depth_min) / depth_span
    depth = np.clip(np.where(depth_mask, 1 - depth_norm, 0.0), 0.0, 1.0)

    raw_normal = np.array(image[..., 1:4], dtype=np.float32)
    # Encode [-1, 1] normals as [0, 1] for 8-bit storage.
    geo_normal = (raw_normal + 1) / 2

    shading_normal = None
    if best_sensor_dict is not None:
        # The conversion needs the raw world-space normals (not the [0, 1]-encoded
        # ones) and the camera description (not the Mitsuba sensor dict, which
        # has no 'camera_position' key).
        shading_normal = world_to_tangent_normals_opengl(raw_normal, best_sensor_dict)
        # Background pixels have no normal; write them as 0.
        shading_normal = np.where(depth_mask, shading_normal, 0.0)

    if output_dir is None:
        return

    # NOTE: cv2.imwrite treats 3-channel arrays as BGR, so the x component lands in
    # the PNG's blue channel. cv2.imread returns the original channel order.
    outputs = {
        "mask.png": mesh_mask * 255,
        "depth.png": (depth * 255).astype(np.uint8),
        "geo_normal.png": (geo_normal * 255).astype(np.uint8),
    }
    if shading_normal is not None:
        outputs["normal.png"] = (shading_normal * 255).astype(np.uint8)

    for file_name, pixels in outputs.items():
        cv2.imwrite(os.path.join(output_dir, file_name), pixels)

    if down_scale:
        for size in ([1024, 1024], [512, 512]):
            for file_name in outputs:
                down_scale_image(os.path.join(output_dir, file_name), size)

def render_textureless(scene, scene_dict, output_dir=None, spp=128, down_scale=False):
    """Render the scene's grey diffuse materials with a path tracer and save ``initial.png``.

    Args:
        scene: Loaded Mitsuba scene (sensor 0 is used).
        scene_dict: Unused; kept for a signature parallel to the other render helpers.
        output_dir: Directory for ``initial.png``; nothing is written when None.
        spp: Samples per pixel.
        down_scale: Also write 1024- and 512-pixel copies.
    """
    path_integrator = mi.load_dict({
        'type': 'path',
        'hide_emitters': True,
    })
    rendered = mi.render(scene, sensor=0, integrator=path_integrator, spp=spp)

    rgba = np.clip(np.array(rendered[..., :4], dtype=np.float32), 0, 1)
    # Premultiply colour by alpha so uncovered pixels go to black.
    shaded = rgba[..., :3] * rgba[..., 3:4]

    if output_dir is None:
        return
    save_path = os.path.join(output_dir, "initial.png")
    cv2.imwrite(save_path, (shaded * 255).astype(np.uint8))
    if not down_scale:
        return
    for size in ([1024, 1024], [512, 512]):
        down_scale_image(save_path, size)

def down_scale_image(image_path, target_size):
    """Resize a square image with Lanczos filtering and save it next to the original.

    The output is ``<stem>_<target_size[0]><ext>``; e.g. ``mask.png`` with
    ``[512, 512]`` becomes ``mask_512.png``. Only the final extension is
    rewritten, so directory names containing ".png" are untouched. A non-PNG
    input no longer overwrites itself.

    Args:
        image_path: Path of the source image.
        target_size: (width, height) of the resized image.

    Raises:
        ValueError: if the source image is not square.
    """
    stem, ext = os.path.splitext(image_path)
    output_path = f"{stem}_{target_size[0]}{ext}"
    # Context manager closes the source file handle once the pixels are resized.
    with Image.open(image_path) as original_image:
        original_size = original_image.size
        if original_size[0] != original_size[1]:
            raise ValueError("Only support square image")
        print(f"Downscale image from {original_size} to {target_size}")
        resized_image = original_image.resize(tuple(target_size), Image.LANCZOS)
    resized_image.save(output_path)


def world_to_tangent_normals_opengl(normal_map, best_sensor_dict):
    """Rotate world-space normals into the camera frame and encode them as a normal map.

    Args:
        normal_map: (H, W, 3) array of world-space normals in [-1, 1], e.g. the raw
            ``gn:geo_normal`` AOV. Presumably unit length (TODO confirm); the
            result is re-normalised either way.
        best_sensor_dict: Camera description with 'camera_position',
            'camera_target' and 'camera_up' entries.

    Returns:
        (H, W, 3) array in [0, 1]: camera-space (x, y, |z|), normalised and mapped
        with ``n * 0.5 + 0.5``. Pixels whose normal has zero length (e.g.
        background with no hit) are returned as 0 instead of NaN.
    """
    camera_transform = T().look_at(
        origin=ScalarPoint3f(best_sensor_dict['camera_position']),
        target=ScalarPoint3f(best_sensor_dict['camera_target']),
        up=ScalarPoint3f(best_sensor_dict['camera_up'])
    )
    # For a rigid look-at transform, the upper-left 3x3 of the inverse transpose is
    # the rotation itself. Multiplying row vectors by it presumably applies the
    # transposed (world -> camera) rotation.
    camera_inv_transform = np.array(camera_transform.inverse_transpose, dtype=np.float32)

    h, w = normal_map.shape[:2]
    camera_space_normals = np.dot(normal_map.reshape(-1, 3), camera_inv_transform[:3, :3])

    # R <- camera x, G <- camera y, B <- |camera z|. Z is forced positive so every
    # normal faces the viewer.
    # NOTE(review): Mitsuba's look_at is believed to build its frame from a "left"
    # vector (cross(up, dir)), so camera +X may point to image-left rather than
    # right. Confirm against a rendered sample before relying on the R channel's sign.
    tangent_space_normals = np.empty_like(camera_space_normals)
    tangent_space_normals[:, 0] = camera_space_normals[:, 0]
    tangent_space_normals[:, 1] = camera_space_normals[:, 1]
    tangent_space_normals[:, 2] = np.abs(camera_space_normals[:, 2])

    norm = np.sqrt(np.sum(tangent_space_normals ** 2, axis=1))
    valid = norm > 0
    # Divide only by non-zero lengths. Zero-length (background) normals would
    # otherwise become NaN, and NaN has no defined uint8 value downstream.
    safe_norm = np.where(valid, norm, 1)
    tangent_space_normals = tangent_space_normals / safe_norm[:, np.newaxis]

    # Map [-1, 1] to [0, 1] (OpenGL-style normal map encoding).
    normal_map_rgb = tangent_space_normals * 0.5 + 0.5
    normal_map_rgb[~valid] = 0.0

    return normal_map_rgb.reshape(h, w, 3)


# Fallback Blender binary used when neither the argument nor BLENDER_EXECUTABLE is set.
_DEFAULT_BLENDER_EXECUTABLE = "/home/swu/cyh/blender-3.2.2-linux-x64/blender"


def run_blender_segmentation(
        obj_dir,
        camera_json_data,  # camera parameters as a Python dict
        mesh_translation,
        mesh_rotation,
        mesh_scale,
        resolution,
        output_dir,
        blender_executable=None,
):
    """Write the camera to ``<output_dir>/best_view.json`` and run ``blender_render.py`` headless.

    Args:
        obj_dir: Directory with the obj files, forwarded as ``--obj_dir``.
        camera_json_data: JSON-serialisable camera parameters.
        mesh_translation: Three values forwarded as ``--mesh_translation``.
        mesh_rotation: Three values forwarded as ``--mesh_rotation``.
        mesh_scale: Scalar forwarded as ``--mesh_scale``.
        resolution: (width, height) forwarded as ``--resolution``.
        output_dir: Output directory (created if missing), forwarded as ``--output_dir``.
        blender_executable: Blender binary. Defaults to the ``BLENDER_EXECUTABLE``
            environment variable, then to ``_DEFAULT_BLENDER_EXECUTABLE``.

    Returns:
        True when Blender exits with status 0.

    Raises:
        subprocess.CalledProcessError: if Blender exits non-zero. It is raised
            only after Blender's stdout/stderr have been printed.
    """
    os.makedirs(output_dir, exist_ok=True)
    camera_json_path = os.path.join(output_dir, "best_view.json")
    with open(camera_json_path, 'w') as f:
        json.dump(camera_json_data, f, indent=4)

    # blender_render.py is expected to live next to this file.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    blender_script_path = os.path.join(current_dir, "blender_render.py")

    if blender_executable is None:
        blender_executable = os.environ.get("BLENDER_EXECUTABLE", _DEFAULT_BLENDER_EXECUTABLE)

    blender_cmd = [
        blender_executable,
        "--background",  # headless mode
        "--python", blender_script_path,
        "--",  # everything after this is passed to the Python script
        "--obj_dir", obj_dir,
        "--camera_json", camera_json_path,
        "--mesh_translation", str(mesh_translation[0]), str(mesh_translation[1]), str(mesh_translation[2]),
        "--mesh_rotation", str(mesh_rotation[0]), str(mesh_rotation[1]), str(mesh_rotation[2]),
        "--mesh_scale", str(mesh_scale),
        "--resolution", str(resolution[0]), str(resolution[1]),
        "--output_dir", output_dir
    ]

    # Drop DISPLAY so Blender does not try to reach an X server.
    env = os.environ.copy()
    env.pop('DISPLAY', None)

    print(f"Running command: {' '.join(blender_cmd)}")
    # check=False: with check=True a failure raised before the captured output was
    # printed, hiding Blender's error messages.
    result = subprocess.run(blender_cmd, check=False, capture_output=True, text=True, env=env)

    print("Blender stdout:")
    print(result.stdout)

    if result.stderr:
        print("Blender stderr:")
        print(result.stderr)

    # Preserve the original contract: non-zero exit raises CalledProcessError.
    result.check_returncode()
    return result.returncode == 0

def render_best_view_geometry(best_sensor_dict, mesh_path, mesh_translation, mesh_scale, mesh_rotation,
                              resolution=(512, 512), spp=128, output_dir=None):
    """Render masks/segmentation (via Blender) and a textureless image (via Mitsuba) for one view.

    Args:
        best_sensor_dict: Camera description ('camera_position', 'camera_target',
            'camera_up'). It is written to ``<output_dir>/best_view.json`` by the
            Blender step and also used for the Mitsuba sensor.
        mesh_path: Directory containing ``obj_<idx>.obj`` files.
        mesh_translation, mesh_scale, mesh_rotation: Object transform. For
            Blender the rotation is passed reordered as [r0, r2, r1].
        resolution: (width, height) of the renders.
        spp: Samples per pixel for the textureless render.
        output_dir: Directory for all outputs. Required in practice; the Blender
            and mask steps join paths onto it.

    Returns:
        ``best_sensor_dict``, unchanged.

    Raises:
        FileNotFoundError: if ``mesh_path`` contains no ``obj_*.obj`` files.
            Checked before launching Blender.
    """
    objs = sorted(glob(f"{mesh_path}/obj_*.obj"))
    if not objs:
        raise FileNotFoundError(f"No obj_*.obj files found in {mesh_path}")
    # Seed before any random colour selection so segmentations are reproducible.
    set_seed(0)
    obj_idx_list = [int(re.search(r'\d+', os.path.basename(obj_file)).group(0)) for obj_file in objs]

    # Blender step: per-object masks.
    blender_rotation = [mesh_rotation[0], mesh_rotation[2], mesh_rotation[1]]
    run_blender_segmentation(mesh_path, best_sensor_dict, mesh_translation, blender_rotation, mesh_scale,
                             resolution, output_dir)

    convert_mask_to_segmentation(obj_idx_list, output_dir)

    # Mitsuba step: build a scene with every mesh and render the textureless image.
    scene_dict = create_scene_dict(objs, obj_idx_list, best_sensor_dict, resolution,
                                   mesh_translation, mesh_scale, mesh_rotation)
    scene = mi.load_dict(scene_dict)
    render_textureless(scene, scene_dict, output_dir=output_dir, spp=spp)

    return best_sensor_dict


# Fixed colour palette for segmentation maps, presumably the ADE20K palette used by
# ControlNet's segmentation annotator. Built once at import time instead of on every
# call. Rows are RGB; note that [140, 140, 140] appears twice.
_CONTROLNET_SEG_PALETTE = np.asarray([
    [120, 120, 120],
    [180, 120, 120],
    [6, 230, 230],
    [80, 50, 50],
    [4, 200, 3],
    [120, 120, 80],
    [140, 140, 140],
    [204, 5, 255],
    [230, 230, 230],
    [4, 250, 7],
    [224, 5, 255],
    [235, 255, 7],
    [150, 5, 61],
    [120, 120, 70],
    [8, 255, 51],
    [255, 6, 82],
    [143, 255, 140],
    [204, 255, 4],
    [255, 51, 7],
    [204, 70, 3],
    [0, 102, 200],
    [61, 230, 250],
    [255, 6, 51],
    [11, 102, 255],
    [255, 7, 71],
    [255, 9, 224],
    [9, 7, 230],
    [220, 220, 220],
    [255, 9, 92],
    [112, 9, 255],
    [8, 255, 214],
    [7, 255, 224],
    [255, 184, 6],
    [10, 255, 71],
    [255, 41, 10],
    [7, 255, 255],
    [224, 255, 8],
    [102, 8, 255],
    [255, 61, 6],
    [255, 194, 7],
    [255, 122, 8],
    [0, 255, 20],
    [255, 8, 41],
    [255, 5, 153],
    [6, 51, 255],
    [235, 12, 255],
    [160, 150, 20],
    [0, 163, 255],
    [140, 140, 140],
    [250, 10, 15],
    [20, 255, 0],
    [31, 255, 0],
    [255, 31, 0],
    [255, 224, 0],
    [153, 255, 0],
    [0, 0, 255],
    [255, 71, 0],
    [0, 235, 255],
    [0, 173, 255],
    [31, 0, 255],
    [11, 200, 200],
    [255, 82, 0],
    [0, 255, 245],
    [0, 61, 255],
    [0, 255, 112],
    [0, 255, 133],
    [255, 0, 0],
    [255, 163, 0],
    [255, 102, 0],
    [194, 255, 0],
    [0, 143, 255],
    [51, 255, 0],
    [0, 82, 255],
    [0, 255, 41],
    [0, 255, 173],
    [10, 0, 255],
    [173, 255, 0],
    [0, 255, 153],
    [255, 92, 0],
    [255, 0, 255],
    [255, 0, 245],
    [255, 0, 102],
    [255, 173, 0],
    [255, 0, 20],
    [255, 184, 184],
    [0, 31, 255],
    [0, 255, 61],
    [0, 71, 255],
    [255, 0, 204],
    [0, 255, 194],
    [0, 255, 82],
    [0, 10, 255],
    [0, 112, 255],
    [51, 0, 255],
    [0, 194, 255],
    [0, 122, 255],
    [0, 255, 163],
    [255, 153, 0],
    [0, 255, 10],
    [255, 112, 0],
    [143, 255, 0],
    [82, 0, 255],
    [163, 255, 0],
    [255, 235, 0],
    [8, 184, 170],
    [133, 0, 255],
    [0, 255, 92],
    [184, 0, 255],
    [255, 0, 31],
    [0, 184, 255],
    [0, 214, 255],
    [255, 0, 112],
    [92, 255, 0],
    [0, 224, 255],
    [112, 224, 255],
    [70, 184, 160],
    [163, 0, 255],
    [153, 0, 255],
    [71, 255, 0],
    [255, 0, 163],
    [255, 204, 0],
    [255, 0, 143],
    [0, 255, 235],
    [133, 255, 0],
    [255, 0, 235],
    [245, 0, 255],
    [255, 0, 122],
    [255, 245, 0],
    [10, 190, 212],
    [214, 255, 0],
    [0, 204, 255],
    [20, 0, 255],
    [255, 255, 0],
    [0, 153, 255],
    [0, 41, 255],
    [0, 255, 204],
    [41, 0, 255],
    [41, 255, 0],
    [173, 0, 255],
    [0, 245, 255],
    [71, 0, 255],
    [122, 0, 255],
    [0, 255, 184],
    [0, 92, 255],
    [184, 255, 0],
    [0, 133, 255],
    [255, 214, 0],
    [25, 194, 194],
    [102, 255, 0],
    [92, 0, 255],
])


def controlnet_seg_colormap(num_objects):
    """Randomly pick one palette colour per object.

    Args:
        num_objects: Number of colours to return (>= 0).

    Returns:
        (num_objects, 3) integer array of RGB colours, a fresh copy safe to
        modify. When ``num_objects`` does not exceed the palette size, palette
        entries are not reused. Otherwise every entry is used once per round (in
        shuffled order) and the rounds repeat, instead of raising ValueError as
        before. Randomness comes from the global ``np.random`` state.
    """
    palette_size = _CONTROLNET_SEG_PALETTE.shape[0]
    if num_objects <= palette_size:
        # Same RNG call as before, so seeded runs keep producing the same colours.
        chosen = np.random.choice(palette_size, num_objects, replace=False)
    else:
        chosen = np.resize(np.random.permutation(palette_size), num_objects)
    return _CONTROLNET_SEG_PALETTE[chosen]


def convert_mask_to_segmentation(obj_idx_list, output_dir):
    """Binarise per-object masks and merge them into a union mask and a coloured segmentation.

    Reads ``<output_dir>/mask_<idx>.png`` for every idx and thresholds it at 0.5.
    Each mask file is overwritten with its 0/255 version. The function then
    writes:
    - ``mask_all.png``: the union of all masks.
    - ``segmentation_mask.png``: each object painted with a colour from
      ``controlnet_seg_colormap``. Later objects win where masks overlap.

    Args:
        obj_idx_list: Object ids whose masks should be merged.
        output_dir: Directory holding the masks and receiving the outputs.

    Returns:
        (H, W, 3) uint8 RGB segmentation image.

    Raises:
        ValueError: if ``obj_idx_list`` is empty.
    """
    if not obj_idx_list:
        raise ValueError("obj_idx_list is empty: no object masks to merge")

    mask_list = []
    for idx in obj_idx_list:
        mask_path = os.path.join(output_dir, f"mask_{idx}.png")
        # Close the source file before overwriting it with the binarised mask.
        with Image.open(mask_path) as mask_image:
            gray = np.array(mask_image.convert('L'), dtype=np.float32) / 255.0
        mask = (gray > 0.5).astype(np.uint8)
        mask_list.append(mask)
        Image.fromarray(mask * 255).save(mask_path)  # save the binarised mask

    # Union of all masks. A bitwise OR cannot wrap around, unlike summing uint8 masks.
    mask_all = np.zeros_like(mask_list[0])
    for mask in mask_list:
        mask_all |= mask

    segmentation = np.zeros((*mask_list[0].shape, 3), dtype=np.uint8)
    colormap = controlnet_seg_colormap(len(obj_idx_list))
    for color, mask in zip(colormap, mask_list):
        # Direct assignment keeps segmentation uint8. np.where promoted it to int64.
        segmentation[mask == 1] = color

    if output_dir is not None:
        # The palette is RGB, so save via PIL. cv2.imwrite would interpret the array
        # as BGR and store swapped colours that are no longer palette entries.
        Image.fromarray(segmentation).save(os.path.join(output_dir, "segmentation_mask.png"))
        cv2.imwrite(os.path.join(output_dir, "mask_all.png"), mask_all * 255)

    return segmentation


def process_folder_with_best_view(input_folder, output_base_dir, translation=None, scale=1.0, rotation=None):
    """Find the best view and render geometry images for every sub-folder of ``input_folder``.

    Each sub-folder containing ``obj_*.obj`` files is written to
    ``<output_base_dir>/<input folder name>-<sub-folder name>``. A failure in one
    sub-folder is reported (with traceback) and processing continues with the next.

    Args:
        input_folder: Folder whose immediate sub-folders are processed.
        output_base_dir: Parent directory for the per-sub-folder outputs.
        translation: Object translation; defaults to [0, 0, 0].
        scale: Object scale.
        rotation: Object rotation; defaults to [0, 0, 0].
    """
    import traceback

    # Fresh lists per call instead of mutable default arguments shared across calls.
    translation = [0, 0, 0] if translation is None else translation
    rotation = [0, 0, 0] if rotation is None else rotation

    input_path = Path(input_folder)
    output_base_path = Path(output_base_dir)

    if not input_path.exists():
        print(f"错误：输入文件夹 {input_folder} 不存在")
        return

    # Sorted so sub-folders are processed in a reproducible order.
    subdirs = sorted(d for d in input_path.iterdir() if d.is_dir())

    if not subdirs:
        print(f"警告：在 {input_folder} 中没有找到子文件夹")
        return

    print(f"找到 {len(subdirs)} 个子文件夹")

    for subdir in subdirs:
        print(f"\n处理子文件夹: {subdir.name}")

        if not any(subdir.glob("obj_*.obj")):
            print(f"  跳过 {subdir.name}：没有找到obj文件")
            continue

        output_dir = output_base_path / f"{input_path.name}-{subdir.name}"
        output_dir.mkdir(parents=True, exist_ok=True)

        print(f"  输出目录: {output_dir}")

        try:
            # 1. Find the best view.
            print("  查找最佳视角...")
            best_view_dict = find_best_view(
                str(subdir),
                translation,
                scale,
                rotation,
                str(output_dir)
            )

            # 2. Save the best-view camera.
            best_view_path = output_dir / "best_view.json"
            with open(best_view_path, 'w') as f:
                json.dump(best_view_dict, f, indent=2)

            print(f"  最佳视角已保存到: {best_view_path}")

            # 3. Render the images for that view.
            print("  渲染最佳视角图像...")
            success = render_best_view_geometry(
                best_view_dict,
                str(subdir),
                translation,
                scale,
                rotation,
                resolution=[1024, 1024],
                spp=128,
                output_dir=str(output_dir)
            )

            # render_best_view_geometry returns the camera dict, so this is falsy
            # only if that dict is empty.
            if success:
                print(f"  渲染完成，图像已保存到: {output_dir}")
            else:
                print(f"  渲染失败: {subdir}")

        except Exception as e:
            # Best-effort batch: report the failure with its traceback, then move on.
            print(f"  处理 {subdir.name} 时出错: {str(e)}")
            traceback.print_exc()
            continue

    print(f"\n所有子文件夹处理完成！")

def rerender_from_output_dir(output_dir, mesh_path, mesh_translation, mesh_scale, mesh_rotation, new_output_dir=None):
    """Re-render a view from the camera saved in ``<output_dir>/best_view.json``.

    Args:
        output_dir: Directory (str or Path) holding a previous ``best_view.json``.
        mesh_path: Directory containing the ``obj_<idx>.obj`` files.
        mesh_translation, mesh_scale, mesh_rotation: Object transform.
        new_output_dir: Where to write the new renders. It is created if missing,
            and ``best_view.json`` is copied into it. When None, the renders go
            into ``output_dir``.

    Returns:
        The camera dict, as returned by ``render_best_view_geometry``.
    """
    import shutil  # not imported at file level

    output_dir = Path(output_dir)
    best_view_path = output_dir / "best_view.json"
    with open(best_view_path, 'r') as f:
        best_view_dict = json.load(f)
    # Accept both layouts: a wrapper with a "best_sensor_dict" key (what this
    # function originally expected) and the bare camera dict that
    # run_blender_segmentation writes.
    best_sensor_dict = best_view_dict.get("best_sensor_dict", best_view_dict)

    if new_output_dir is None:
        target_dir = output_dir
    else:
        target_dir = Path(new_output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        try:
            shutil.copy(best_view_path, target_dir / "best_view.json")
        except shutil.SameFileError:
            pass  # new_output_dir is output_dir; nothing to copy

    return render_best_view_geometry(best_sensor_dict, mesh_path, mesh_translation, mesh_scale, mesh_rotation,
                                     resolution=[1024, 1024], spp=128, output_dir=str(target_dir))



def main():
    """Example entry point: re-render one mesh folder from its saved best-view camera.

    Other modes, enabled by swapping the final call:
      * batch: ``process_folder_with_best_view(input_folder, output_dir, translation, scale, rotation)``
      * re-render: ``rerender_from_output_dir(old_output_dir, input_folder, translation, scale,
        rotation, new_output_dir=output_dir)``
      * single folder from scratch: ``find_best_view(input_folder, translation, scale, rotation,
        output_dir)``, then dump the result to ``best_view.json``.
    """
    input_folder = "/home/swu/cyh/MasterGraduationProject/3rdparty/b_model/mesh/chair/000"
    output_dir = "/home/swu/szp/MaterialTransfer/a_model/new_view/chair/000"

    # Object transform parameters (adjust as needed).
    translation, scale, rotation = [0, 0, 0], 1.0, [90, 90, 0]

    os.makedirs(output_dir, exist_ok=True)
    camera_path = os.path.join(output_dir, "best_view.json")
    with open(camera_path, "r") as camera_file:
        camera = json.load(camera_file)

    render_best_view_geometry(
        camera, input_folder, translation, scale, rotation,
        resolution=[1024, 1024], spp=128, output_dir=output_dir,
    )


if __name__ == "__main__":
    main()
