import bpy
import math
import numpy as np
import os
import argparse
from mathutils import Vector, Matrix
import sys
import json
from glob import glob
import re


def get_args():
    """Parse the script's own command-line options.

    Blender consumes its own arguments, so only the tokens after the first
    ``--`` in ``sys.argv`` are given to argparse. When no ``--`` is present an
    empty list is parsed (which makes argparse exit, since options are required).

    Returns:
        argparse.Namespace holding the parsed options.
    """
    cli = argparse.ArgumentParser(
        description='Generate material masks, surface normals, and depth maps from OBJ files')
    add = cli.add_argument

    add('--obj_dir', type=str, required=True, help='Directory containing OBJ files')
    add('--camera_json', type=str, required=True, help='Path to JSON file containing camera parameters')
    add('--mesh_translation', type=float, nargs=3, required=True,
        help='Translation of the mesh [x, y, z]')
    add('--mesh_rotation', type=float, nargs=3, required=True,
        help='Rotation of the mesh in degrees [x, y, z]')
    add('--mesh_scale', type=float, required=True, help='Scale of the mesh')
    add('--resolution', type=int, nargs=2, required=True,
        help='Resolution of the rendered images [width, height]')
    add('--output_dir', type=str, required=True, help='Output directory for the rendered masks')
    add('--render_normal', action='store_true', help='Render surface normal map')
    add('--render_depth', action='store_true', help='Render depth map')
    add('--normal_format', type=str, default='gl', choices=['gl', 'dx'],
        help='Normal map format: gl (OpenGL) or dx (DirectX)')

    try:
        separator_at = sys.argv.index("--")
    except ValueError:
        script_argv = []
    else:
        script_argv = sys.argv[separator_at + 1:]
    return cli.parse_args(script_argv)


def apply_mesh_transform(obj, mesh_translation, mesh_scale, mesh_rotation):
    """Position, scale and rotate *obj*, baking only the rotation into its mesh.

    Args:
        obj: Blender object to transform.
        mesh_translation: (x, y, z) assigned to ``obj.location``; it stays an
            object-level transform (not applied to the mesh data).
        mesh_scale: uniform factor assigned to all three axes of ``obj.scale``;
            it also stays an object-level transform.
        mesh_rotation: XYZ Euler angles in degrees. They are applied (baked)
            into the mesh data with ``transform_apply``.
    """
    obj.location = mesh_translation
    obj.scale = [mesh_scale] * 3

    obj.rotation_mode = 'XYZ'
    obj.rotation_euler = [math.radians(angle) for angle in mesh_rotation]

    # transform_apply operates on the current selection, not on an argument,
    # so make *obj* the only selected object (and the active one) first;
    # otherwise any other selected object would be rotated as well.
    bpy.ops.object.select_all(action='DESELECT')
    obj.select_set(True)
    bpy.context.view_layer.objects.active = obj
    bpy.ops.object.transform_apply(location=False, rotation=True, scale=False)


def setup_camera(camera_position, camera_target, camera_up, ortho_scale=2.0):
    """Replace all scene cameras with one orthographic camera aimed at a target.

    Args:
        camera_position: (x, y, z) world-space camera location.
        camera_target: (x, y, z) world-space point the camera looks at.
        camera_up: (x, y, z) world-space up hint that fixes the camera roll.
            If None, or (nearly) parallel to the viewing direction, the roll
            falls back to ``to_track_quat('-Z', 'Y')`` (world +Z as up).
        ortho_scale: value for the camera's Blender ``ortho_scale``
            (default 2.0, as before).

    Returns:
        The new camera object, also set as ``scene.camera``.
    """
    # Remove every existing camera object.
    bpy.ops.object.select_all(action='DESELECT')
    for scene_obj in bpy.context.scene.objects:
        if scene_obj.type == 'CAMERA':
            scene_obj.select_set(True)
    if bpy.context.selected_objects:
        bpy.ops.object.delete()

    bpy.ops.object.camera_add()
    camera = bpy.context.active_object

    camera.data.type = "ORTHO"
    camera.data.ortho_scale = ortho_scale
    camera.location = Vector(camera_position)

    direction = Vector(camera_target) - Vector(camera_position)
    look_at = None
    if camera_up is not None and direction.length > 1e-12:
        forward = direction.normalized()
        right = forward.cross(Vector(camera_up))
        if right.length > 1e-6:  # up hint not parallel to the view direction
            right.normalize()
            true_up = right.cross(forward)
            # Blender cameras look down local -Z with local +Y as up, so the
            # rotation's columns are (right, up, -forward).
            look_at = Matrix((right, true_up, -forward)).transposed()

    if look_at is not None:
        camera.rotation_euler = look_at.to_euler()
    else:
        camera.rotation_euler = direction.to_track_quat('-Z', 'Y').to_euler()

    bpy.context.scene.camera = camera
    bpy.context.view_layer.objects.active = camera

    # Switch open 3D viewports to camera view. There may be no screen at all
    # (e.g. when Blender runs headless with -b), so guard against None.
    screen = bpy.context.screen
    if screen is not None:
        for area in screen.areas:
            if area.type == 'VIEW_3D':
                area.spaces.active.region_3d.view_perspective = 'CAMERA'

    return camera


def setup_scene(resolution):
    """Configure Cycles, colour management, output format and a black world.

    Args:
        resolution: (width, height) of the rendered images in pixels.
    """
    scene = bpy.context.scene
    scene.render.engine = 'CYCLES'
    scene.render.film_transparent = True
    scene.render.resolution_x = resolution[0]
    scene.render.resolution_y = resolution[1]
    scene.render.resolution_percentage = 100

    # Optimize Cycles for speed
    scene.cycles.samples = 64
    scene.cycles.use_denoising = False

    # Neutralise colour management so emitted values are written as-is.
    scene.display_settings.display_device = 'None'  # disable the display transform
    scene.view_settings.view_transform = 'Standard'  # Standard instead of Filmic
    scene.view_settings.look = 'None'  # no additional look
    scene.view_settings.exposure = 0  # neutral exposure
    scene.view_settings.gamma = 1.0  # linear gamma

    # Output image format
    scene.render.image_settings.file_format = 'PNG'
    scene.render.image_settings.color_mode = 'RGB'
    scene.render.image_settings.color_depth = '16'  # 16-bit PNG for higher precision

    # Black, zero-strength world so the background contributes no light.
    # Use the scene's own world; create one if the scene has none (indexing
    # bpy.data.worlds['World'] raised KeyError when no such world existed).
    world = scene.world or bpy.data.worlds.get('World') or bpy.data.worlds.new('World')
    scene.world = world
    world.use_nodes = True
    nodes = world.node_tree.nodes
    background = nodes.get("Background")
    if background is None:
        background = nodes.new(type='ShaderNodeBackground')
        world_output = nodes.get("World Output") or nodes.new(type='ShaderNodeOutputWorld')
        world.node_tree.links.new(background.outputs[0], world_output.inputs[0])
    background.inputs[0].default_value = (0, 0, 0, 0)  # colour
    background.inputs[1].default_value = 0  # strength


def create_mask_material():
    """Build a flat white emission material used to paint the masked part.

    Returns:
        A new material requested under the name "MaskMaterial".
    """
    material = bpy.data.materials.new(name="MaskMaterial")
    material.use_nodes = True
    tree = material.node_tree

    # Start from an empty node tree instead of the default nodes.
    tree.nodes.clear()

    glow = tree.nodes.new(type='ShaderNodeEmission')
    glow.inputs[0].default_value = (1, 1, 1, 1)  # pure white colour
    glow.inputs[1].default_value = 1  # unit strength

    surface_out = tree.nodes.new(type='ShaderNodeOutputMaterial')
    tree.links.new(glow.outputs[0], surface_out.inputs[0])

    return material


def create_blocker_material():
    """Build a black diffuse material that occludes geometry without adding to a mask.

    Returns:
        A new material requested under the name "BlockerMaterial".
    """
    material = bpy.data.materials.new(name="BlockerMaterial")
    material.use_nodes = True
    tree = material.node_tree
    tree.nodes.clear()

    # Zero-albedo diffuse: blocks the view and reflects nothing.
    black_diffuse = tree.nodes.new(type='ShaderNodeBsdfDiffuse')
    black_diffuse.inputs['Color'].default_value = (0, 0, 0, 1)

    surface_out = tree.nodes.new(type='ShaderNodeOutputMaterial')
    tree.links.new(black_diffuse.outputs['BSDF'], surface_out.inputs['Surface'])

    return material


def create_default_material(index):
    """Create a node-based material requested under the name ``mat_<index>``.

    Args:
        index: numeric identifier embedded in the material name.

    Returns:
        The new material (Blender may append a suffix if the name is taken).
    """
    material = bpy.data.materials.new(name="mat_{}".format(index))
    material.use_nodes = True
    return material


def set_camera_space_surface_normal_material(obj, normal_format='gl'):
    """Assign an emission material that outputs camera-space normals as colour.

    The Geometry node's world-space normal is transformed to camera space,
    multiplied by (1, 1, -1) (presumably so surfaces facing the camera encode
    as +Z / blue, OpenGL style), then remapped from [-1, 1] to [0, 1] with
    (n + 1) * 0.5.

    Args:
        obj: object whose material slots all receive the new material.
        normal_format: 'gl' (OpenGL, +Y up) or 'dx' (DirectX). DirectX normal
            maps differ from OpenGL ones by an inverted green (Y) channel.
            Any value other than 'dx' is treated as 'gl'.

    Returns:
        The created material.
    """
    mat = bpy.data.materials.new("CameraSpaceNormalMaterial")
    mat.use_nodes = True

    nodes = mat.node_tree.nodes
    links = mat.node_tree.links
    nodes.clear()

    geometry_node = nodes.new(type='ShaderNodeNewGeometry')
    vector_transform_node = nodes.new(type='ShaderNodeVectorTransform')
    multiply_node_adjust = nodes.new(type='ShaderNodeVectorMath')  # flip Z
    math_node_add = nodes.new(type='ShaderNodeVectorMath')  # n + 1
    math_node_multiply = nodes.new(type='ShaderNodeVectorMath')  # * 0.5
    emission_node = nodes.new(type='ShaderNodeEmission')
    output_node = nodes.new(type='ShaderNodeOutputMaterial')

    multiply_node_adjust.operation = 'MULTIPLY'
    multiply_node_adjust.inputs[1].default_value = (1.0, 1.0, -1.0)

    math_node_add.operation = 'ADD'
    math_node_add.inputs[1].default_value = (1.0, 1.0, 1.0)

    math_node_multiply.operation = 'MULTIPLY'
    math_node_multiply.inputs[1].default_value = (0.5, 0.5, 0.5)

    emission_node.inputs['Strength'].default_value = 1.0

    vector_transform_node.vector_type = 'NORMAL'
    vector_transform_node.convert_from = 'WORLD'
    vector_transform_node.convert_to = 'CAMERA'

    links.new(geometry_node.outputs['Normal'], vector_transform_node.inputs['Vector'])
    links.new(vector_transform_node.outputs['Vector'], multiply_node_adjust.inputs[0])
    links.new(multiply_node_adjust.outputs[0], math_node_add.inputs[0])
    links.new(math_node_add.outputs[0], math_node_multiply.inputs[0])

    if normal_format == 'dx':
        # DirectX convention: invert the green (Y) channel, g' = 1 - g,
        # implemented as g * -1 + 1 while leaving R and B untouched.
        dx_math_node_multiply = nodes.new(type='ShaderNodeVectorMath')
        dx_math_node_add = nodes.new(type='ShaderNodeVectorMath')
        dx_math_node_multiply.operation = 'MULTIPLY'
        dx_math_node_multiply.inputs[1].default_value = (1.0, -1.0, 1.0)
        dx_math_node_add.operation = 'ADD'
        dx_math_node_add.inputs[1].default_value = (0.0, 1.0, 0.0)

        links.new(math_node_multiply.outputs[0], dx_math_node_multiply.inputs[0])
        links.new(dx_math_node_multiply.outputs[0], dx_math_node_add.inputs[0])
        links.new(dx_math_node_add.outputs[0], emission_node.inputs['Color'])
    else:
        links.new(math_node_multiply.outputs[0], emission_node.inputs['Color'])

    links.new(emission_node.outputs['Emission'], output_node.inputs['Surface'])

    # Apply this material to every material slot of the object.
    for slot in obj.material_slots:
        slot.material = mat

    print(f"Camera space surface normal material applied (format: {normal_format}).")

    return mat


def set_depth_material(obj, camera_distance=2.0, depth_output='View Z Depth', depth_margin=1.5):
    """
    Set material to render a depth map.

    Depth values are mapped linearly from [near_clip, far_clip] to grey
    levels 1.0 (near, white) .. 0.0 (far, black), clamped outside the range:
        near_clip = max(0.1, camera_distance - depth_margin)
        far_clip  = camera_distance + depth_margin
    With the defaults this assumes the geometry lies within +/-1.5 units of
    the camera-to-target distance (e.g. a mesh normalized to a [-1, 1] cube
    around the target).

    Args:
        obj: object whose material slots all receive the depth material.
        camera_distance: distance from the camera to the scene centre.
        depth_output: Camera Data output used as depth. 'View Z Depth'
            measures along the camera's viewing axis, which is what an
            orthographic camera needs. 'View Distance' is the straight-line
            distance to the camera origin and so grows with lateral offset
            from the image centre.
        depth_margin: half-width of the depth range around camera_distance.

    Returns:
        The created material.

    Raises:
        ValueError: if depth_output is not a supported Camera Data output.
    """
    if depth_output not in ('View Z Depth', 'View Distance'):
        raise ValueError(f"Unsupported depth_output: {depth_output!r}")

    mat = bpy.data.materials.new("DepthMaterial")
    mat.use_nodes = True

    nodes = mat.node_tree.nodes
    links = mat.node_tree.links
    nodes.clear()

    camera_data = nodes.new(type='ShaderNodeCameraData')
    map_range = nodes.new(type='ShaderNodeMapRange')
    emission_node = nodes.new(type='ShaderNodeEmission')
    output_node = nodes.new(type='ShaderNodeOutputMaterial')

    near_clip = max(0.1, camera_distance - depth_margin)
    far_clip = camera_distance + depth_margin

    map_range.clamp = True  # keep out-of-range depths at pure white / black
    map_range.inputs['From Min'].default_value = near_clip
    map_range.inputs['From Max'].default_value = far_clip
    map_range.inputs['To Min'].default_value = 1.0  # Near is white
    map_range.inputs['To Max'].default_value = 0.0  # Far is black

    links.new(camera_data.outputs[depth_output], map_range.inputs['Value'])
    links.new(map_range.outputs[0], emission_node.inputs['Color'])
    links.new(emission_node.outputs['Emission'], output_node.inputs['Surface'])

    # Apply this material to every material slot of the object.
    for slot in obj.material_slots:
        slot.material = mat

    print(f"Depth material applied with near clip: {near_clip:.2f}, far clip: {far_clip:.2f}")

    return mat


def join_objects(objects):
    """Merge *objects* into one object with Blender's join operator.

    The first element is made active, so the others are joined into it;
    material slots are carried over by the operator.

    Args:
        objects: sequence of Blender objects to join.

    Returns:
        None when *objects* is empty, otherwise the active object after joining.
    """
    if not objects:
        return None

    receiver = objects[0]
    bpy.ops.object.select_all(action='DESELECT')
    bpy.context.view_layer.objects.active = receiver
    for item in objects:
        item.select_set(True)

    bpy.ops.object.join()
    return bpy.context.active_object


def _import_obj_files(obj_files, obj_idx_list):
    """Import each OBJ file and give all of its meshes one material ``mat_<index>``.

    Returns:
        (imported_objects, index_by_material): the imported mesh objects, and
        a dict mapping each created material's actual name (Blender may add
        a suffix such as ".001" on clashes) to its integer index.
    """
    imported_objects = []
    index_by_material = {}

    for obj_file, index in zip(obj_files, obj_idx_list):
        # Deselect first so that afterwards the selection holds only the
        # objects created by this import.
        bpy.ops.object.select_all(action='DESELECT')
        bpy.ops.import_scene.obj(filepath=obj_file)
        new_objects = [o for o in bpy.context.selected_objects if o.type == 'MESH']
        if not new_objects:
            print(f"No mesh objects imported from {obj_file}")
            continue

        mat = create_default_material(index)
        index_by_material[mat.name] = index

        for imported_obj in new_objects:
            materials = imported_obj.data.materials
            if len(materials) == 0:
                materials.append(mat)
            else:
                # Overwrite every slot (not only the first) so faces that
                # used other MTL materials still belong to this object's mask.
                for slot_idx in range(len(materials)):
                    materials[slot_idx] = mat
            imported_objects.append(imported_obj)

    return imported_objects, index_by_material


def _load_camera_data(camera_json, output_dir):
    """Return camera parameters from a dict, or from a JSON file path.

    A relative path is resolved against *output_dir*.

    Raises:
        ValueError: if camera_json is neither a str nor a dict.
    """
    if isinstance(camera_json, str):
        camera_json_path = camera_json
        if not os.path.isabs(camera_json_path):
            camera_json_path = os.path.join(output_dir, camera_json)
        with open(camera_json_path, encoding='utf-8') as f:
            return json.load(f)
    if isinstance(camera_json, dict):
        return camera_json
    raise ValueError("camera_json must be a path to a JSON file or a dictionary")


def _render_material_masks(material_slots, index_by_material, mask_material, blocker_material, output_dir):
    """Render ``mask_<index>.png`` once per distinct material in *material_slots*.

    Slots holding the target material get *mask_material*; all other slots
    (including empty ones) get *blocker_material*, so occluded parts stay
    hidden. Original slot materials are restored afterwards, even on error.
    """
    original_materials = [slot.material for slot in material_slots]
    scene = bpy.context.scene
    done = set()
    try:
        for target in original_materials:
            if target is None or target.name in done:
                continue
            done.add(target.name)

            material_index = index_by_material.get(target.name)
            if material_index is None:
                print(f"Skipping {target.name}: not created for any OBJ file")
                continue
            mask_output_name = f"mask_{material_index}.png"
            print(f"Generating mask for {target.name} as {mask_output_name}...")

            for slot, original in zip(material_slots, original_materials):
                is_target = original is not None and original.name == target.name
                slot.material = mask_material if is_target else blocker_material

            scene.render.filepath = os.path.join(output_dir, mask_output_name)
            bpy.ops.render.render(write_still=True)
    finally:
        for slot, original in zip(material_slots, original_materials):
            slot.material = original


def render_mask(obj_files, obj_idx_list, camera_json, mesh_translation, mesh_scale, mesh_rotation,
                resolution, output_dir, render_normal=True, render_depth=True, normal_format='gl'):
    """Render per-object masks and, optionally, a normal map and a depth map.

    Args:
        obj_files: OBJ file paths; each file becomes one mask.
        obj_idx_list: integer index per OBJ file, used in ``mask_<index>.png``.
        camera_json: camera parameter dict, or path to a JSON file with keys
            'camera_position', 'camera_target' and optionally 'camera_up'
            (relative paths are resolved against *output_dir*).
        mesh_translation, mesh_scale, mesh_rotation: see apply_mesh_transform.
        resolution: (width, height) in pixels.
        output_dir: directory receiving the PNG files (created if missing).
        render_normal: also write ``sh_normal.png``.
        render_depth: also write ``depth.png``.
        normal_format: 'gl' or 'dx', forwarded to the normal material.
    """
    # Start from an empty scene.
    bpy.ops.object.select_all(action='SELECT')
    bpy.ops.object.delete()

    if not obj_files:
        print("No OBJ files found.")
        return

    imported_objects, index_by_material = _import_obj_files(obj_files, obj_idx_list)

    combined_obj = join_objects(imported_objects)
    if not combined_obj:
        print("Failed to combine objects")
        return

    setup_scene(resolution)
    apply_mesh_transform(combined_obj, mesh_translation, mesh_scale, mesh_rotation)

    # Smooth-shade every face and enable auto smooth.
    combined_obj.data.use_auto_smooth = True
    for poly in combined_obj.data.polygons:
        poly.use_smooth = True

    camera_data = _load_camera_data(camera_json, output_dir)
    camera_position = camera_data['camera_position']
    camera_target = camera_data['camera_target']
    camera_up = camera_data.get('camera_up')
    setup_camera(camera_position, camera_target, camera_up)

    material_slots = combined_obj.material_slots
    if not material_slots:
        print("No materials found in combined object")
        return

    mask_material = create_mask_material()
    blocker_material = create_blocker_material()

    os.makedirs(output_dir, exist_ok=True)

    _render_material_masks(material_slots, index_by_material, mask_material, blocker_material, output_dir)

    if render_normal:
        print("Generating surface normal map...")
        normal_material = set_camera_space_surface_normal_material(combined_obj, normal_format)
        bpy.context.scene.render.filepath = os.path.join(output_dir, "sh_normal.png")
        bpy.ops.render.render(write_still=True)
        bpy.data.materials.remove(normal_material)

    if render_depth:
        print("Generating depth map...")
        # Centre the depth range on the actual camera-to-target distance
        # instead of assuming the camera sits 2.0 units away.
        camera_distance = (Vector(camera_position) - Vector(camera_target)).length
        depth_material = set_depth_material(combined_obj, camera_distance)
        bpy.context.scene.render.filepath = os.path.join(output_dir, "depth.png")
        bpy.ops.render.render(write_still=True)
        bpy.data.materials.remove(depth_material)

    # Clean up temporary materials.
    bpy.data.materials.remove(mask_material)
    bpy.data.materials.remove(blocker_material)


if __name__ == '__main__':
    args = get_args()

    # Collect obj_*.obj files; the first run of digits in the file name is the
    # mask index. Files without any digits are skipped instead of crashing.
    objs = []
    obj_idx_list = []
    for obj_file in sorted(glob(os.path.join(args.obj_dir, "obj_*.obj"))):
        index_match = re.search(r'\d+', os.path.basename(obj_file))
        if index_match is None:
            print(f"Skipping {obj_file}: no numeric index in file name")
            continue
        objs.append(obj_file)
        obj_idx_list.append(int(index_match.group(0)))

    print("rendering masks for objects:", objs)
    render_mask(
        objs,
        obj_idx_list,
        args.camera_json,
        args.mesh_translation,
        args.mesh_scale,
        args.mesh_rotation,
        args.resolution,
        args.output_dir,
        # Honour the CLI switches; previously they were parsed but ignored.
        render_normal=args.render_normal,
        render_depth=args.render_depth,
        normal_format=args.normal_format,
    )