import os
import re
import torch
import numpy as np
from torchvision.utils import save_image
import drjit as dr
import mitsuba as mi
mi.set_variant('cuda_ad_rgb')
from mitsuba import ScalarTransform4f as T, ScalarPoint3f, Point3f, Transform4f as DT
from utils import get_obj_files_list
from materials import UniformMaterial, ProceduralMaterial
from config import MATERIAL_INIT_PATH, HDR_ROOT_DIR

def create_texture_bitmap(texture_path, uv_scale=(1.0, 1.0), uv_rotation=0.0, raw=False, uv_flip=True):
    """Return a Mitsuba ``bitmap`` texture dict with a UV scale/rotation transform.

    Args:
        texture_path: Image file loaded by the bitmap plugin.
        uv_scale: ``(u_scale, v_scale)`` pair; unpacked, so it must have exactly two items.
        uv_rotation: Rotation about the z axis, passed to Mitsuba's ``rotate``
            (degrees, per Mitsuba's transform convention).
        raw: Forwarded as the bitmap's ``raw`` flag (True for data maps such as
            metallic/roughness/normal that must not be colour-converted).
        uv_flip: When True, an extra ``scale([1, -1, 1])`` is chained on to mirror V.

    Returns:
        dict: Plugin description with keys 'type', 'filename', 'raw', 'to_uv'.
    """
    u_scale, v_scale = uv_scale

    # ``to_uv`` is built with T (ScalarTransform4f); the bitmap plugin accepts a 4x4 transform here.
    to_uv = T().scale([u_scale, v_scale, 1.0])
    to_uv = to_uv.rotate([0, 0, 1], uv_rotation)
    if uv_flip:
        to_uv = to_uv.scale([1, -1, 1])

    texture = {
        'type': 'bitmap',
        'filename': texture_path,
        'raw': raw,
    }
    texture['to_uv'] = to_uv
    return texture


def create_object_transform(translation, scale, rotation):
    """Build an object's ``to_world`` transform from scale, Euler angles and translation.

    The transform is produced by chaining, in this call order:
    uniform ``scale`` -> rotate about X by ``rotation[0]`` -> about Y by
    ``rotation[1]`` -> about Z by ``rotation[2]`` -> ``translate(translation)``.

    NOTE(review): in Mitsuba's chained transform API each call appears to
    right-multiply (``T().a().b() == A @ B``). If so, points are transformed by
    the *last* call first: translated, then rotated (Z, Y, X), then scaled.
    That is the reverse of the "Scale->Rotate->Translate" order the original
    docstring described. Confirm which order callers' translation values assume
    before changing the chain.

    Args:
        translation: 3-vector passed to ``translate``.
        scale: Scalar uniform scale factor.
        rotation: Indexable with at least three angles (X, Y, Z), passed to ``rotate``.

    Returns:
        The composed ScalarTransform4f.
    """
    axes = ([1, 0, 0], [0, 1, 0], [0, 0, 1])

    result = T().scale([scale] * 3)
    for axis_no, axis in enumerate(axes):
        result = result.rotate(axis, rotation[axis_no])
    return result.translate(translation)


def create_scene_dict(obj_files, obj_idx_list, camera_dict, material_pair_dict, hdr_path, resolution,
                      uv_scale, uv_rotation, translation, scale, rotation,
                      seamless_texture_paths=None, seamless_texture_idx=None):
    """Build a Mitsuba scene dict plus one material model per object.

    Args:
        obj_files: Mesh file paths, paired element-wise (via ``zip``) with ``obj_idx_list``.
        obj_idx_list: Object indices, used for the ``obj_{idx}`` scene keys and to
            look up ``material_pair_dict[str(idx)]``.
        camera_dict: Must contain 'camera_position', 'camera_target' and 'camera_up'.
        material_pair_dict: Maps ``str(idx)`` to a dict with:
            'material_path': dict of 'albedo'/'metallic'/'roughness'/'normal' image
                paths (textured principled BSDF), or None (constant grey principled BSDF);
            'procedural': config dict for ProceduralMaterial, or None (UniformMaterial).
        hdr_path: Environment map file for the 'envmap' emitter.
        resolution: ``(width, height)`` of the film.
        uv_scale: Only ``uv_scale[0]`` is used, as ProceduralMaterial's ``init_scale``.
        uv_rotation: Passed as ProceduralMaterial's ``init_rotation``.
        translation, scale, rotation: Placement shared by every object
            (see ``create_object_transform``).
        seamless_texture_paths: Optional mapping ``idx -> albedo image path``.
        seamless_texture_idx: Optional collection of indices that should use
            the seamless albedo. Both seamless containers are keyed by the raw
            index, not ``str(idx)``.

    Returns:
        tuple: ``(scene_dict, opt_param_keys, procedural_materials)``.
        ``opt_param_keys`` lists the scene-parameter keys of each object's BSDF.
        ``procedural_materials`` maps ``str(idx)`` to a ProceduralMaterial or
        UniformMaterial instance.
    """
    # Base scene: environment light, orthographic camera, path tracer.
    scene_dict = {
        'type': 'scene',
        'emitter': {
            'type': 'envmap',
            'filename': hdr_path,
            'scale': 1.0,
            'to_world': T().rotate(ScalarPoint3f([1, 0, 0]), 90.0)
        },
        'sensor': {
            'type': 'orthographic',
            'near_clip': 0.01,
            'far_clip': 100,
            'to_world': T().look_at(
                origin=ScalarPoint3f(camera_dict['camera_position']),
                target=ScalarPoint3f(camera_dict['camera_target']),
                up=ScalarPoint3f(camera_dict['camera_up'])
            ) @ T().scale(ScalarPoint3f([1, 1, 1])),

            'film': {
                'type': 'hdrfilm',
                'width': resolution[0],
                'height': resolution[1],
                'filter': {'type': 'gaussian'}
            },
        },
        'integrator': {
            'type': 'path',
            'max_depth': 5,
            'hide_emitters': True
        },
    }

    opt_param_keys = []
    procedural_materials = {}

    # Every object gets the same placement, so build the transform once rather
    # than once per loop iteration.
    object_transform = create_object_transform(translation, scale, rotation)

    for idx, obj_file in zip(obj_idx_list, obj_files):
        material_pair = material_pair_dict[f'{idx}']
        material_path = material_pair['material_path']
        procedural_config = material_pair['procedural']

        # Optional seamless albedo override for this object.
        if seamless_texture_idx and idx in seamless_texture_idx and seamless_texture_paths and idx in seamless_texture_paths:
            albedo_path = seamless_texture_paths[idx]
            print(f"Using seamless texture for obj_{idx}: {albedo_path}")
        else:
            albedo_path = None

        print(f"Creating Material for obj_{idx}...")
        if procedural_config is not None:
            procedural_materials[f'{idx}'] = ProceduralMaterial(procedural_config['sbs_file_path'],
                                                                procedural_config['mgt_res'],
                                                                procedural_config['external_input_path'],
                                                                procedural_config['ckp_path'],
                                                                init_scale=uv_scale[0], init_rotation=uv_rotation,
                                                                albedo_image_path=albedo_path,)
        else:
            procedural_materials[f'{idx}'] = UniformMaterial(albedo_image_path=albedo_path)

        obj_key = f'obj_{idx}'
        if material_path is not None:
            # Textured principled BSDF behind a normal map. Only the albedo is
            # loaded with raw=False (colour-converted); the data maps stay raw.
            albedo_texture = create_texture_bitmap(material_path['albedo'], raw=False, uv_flip=False)
            metallic_texture = create_texture_bitmap(material_path['metallic'], raw=True, uv_flip=False)
            roughness_texture = create_texture_bitmap(material_path['roughness'], raw=True, uv_flip=False)
            normal_texture = create_texture_bitmap(material_path['normal'], raw=True, uv_flip=False)

            inner_bsdf = {
                'type': 'normalmap',
                'normalmap': normal_texture,
                'nested': {
                    'type': 'principled',
                    'base_color': albedo_texture,
                    'metallic': metallic_texture,
                    'roughness': roughness_texture,
                }
            }
            opt_param_keys += [
                f"{obj_key}.bsdf.brdf_0.normalmap.data",
                f"{obj_key}.bsdf.brdf_0.nested_bsdf.base_color.data",
                f"{obj_key}.bsdf.brdf_0.nested_bsdf.metallic.data",
                f"{obj_key}.bsdf.brdf_0.nested_bsdf.roughness.data",
            ]
        else:
            # Constant principled BSDF (mid-grey, non-metallic, roughness 0.5).
            inner_bsdf = {
                'type': 'principled',
                'base_color': {
                    'type': 'rgb',
                    'value': [0.5, 0.5, 0.5]
                },
                'metallic': 0.0,
                'roughness': 0.5,
            }
            opt_param_keys += [
                f"{obj_key}.bsdf.brdf_0.base_color.value",
                f"{obj_key}.bsdf.brdf_0.metallic.value",
                f"{obj_key}.bsdf.brdf_0.roughness.value",
            ]

        scene_dict[obj_key] = {
            'type': 'obj',
            'filename': obj_file,
            'to_world': object_transform,
            'bsdf': {
                'type': 'twosided',
                'nested': inner_bsdf,
            }
        }
    return scene_dict, opt_param_keys, procedural_materials

def create_scene_dict_drjit(obj_files, obj_idx_list, camera_dict, material_path, hdr_path, resolution,
                      translation, scale, rotation):
    """Build a Mitsuba scene dict in which every object uses bitmap PBR textures.

    Each object gets its own copy of four bitmap textures created from the same
    ``material_path`` files. The textures feed a principled BSDF behind a normal
    map, wrapped in 'twosided'. The scene uses the 'prb' integrator and a
    4-sample stratified sampler.

    Args:
        obj_files: Mesh file paths, paired element-wise (via ``zip``) with ``obj_idx_list``.
        obj_idx_list: Object indices used for the ``obj_{idx}`` scene keys.
        camera_dict: Must contain 'camera_position', 'camera_target' and 'camera_up'.
        material_path: Dict with 'albedo', 'metallic', 'roughness' and 'normal' image paths.
        hdr_path: Environment map file for the 'envmap' emitter.
        resolution: ``(width, height)`` of the film.
        translation, scale, rotation: Placement shared by all objects.

    Returns:
        tuple: ``(scene_dict, opt_param_keys)``. ``opt_param_keys`` holds four
        texture-data parameter keys per object.
    """
    emitter_to_world = T().rotate(ScalarPoint3f([1, 0, 0]), 90.0)
    camera_to_world = T().look_at(
        origin=ScalarPoint3f(camera_dict['camera_position']),
        target=ScalarPoint3f(camera_dict['camera_target']),
        up=ScalarPoint3f(camera_dict['camera_up'])
    ) @ T().scale(ScalarPoint3f([1, 1, 1]))

    scene_dict = {
        'type': 'scene',
        'emitter': {
            'type': 'envmap',
            'filename': hdr_path,
            'scale': 1.0,
            'to_world': emitter_to_world,
        },
        'sensor': {
            'type': 'orthographic',
            'near_clip': 0.01,
            'far_clip': 100,
            'to_world': camera_to_world,
            'film': {
                'type': 'hdrfilm',
                'width': resolution[0],
                'height': resolution[1],
                'filter': {'type': 'gaussian'},
            },
        },
        'integrator': {
            'type': 'prb',
            'max_depth': 5,
            'hide_emitters': True,
        },
        'sampler': {
            'type': 'stratified',
            'sample_count': 4
        },
    }

    shared_transform = create_object_transform(translation, scale, rotation)

    # (texture name, raw flag), in the order the textures are created.
    texture_specs = (
        ('albedo', False),
        ('metallic', True),
        ('roughness', True),
        ('normal', True),
    )

    opt_param_keys = []
    for idx, obj_file in zip(obj_idx_list, obj_files):
        tex = {
            name: create_texture_bitmap(material_path[name], raw=is_raw, uv_flip=False)
            for name, is_raw in texture_specs
        }

        obj_key = f'obj_{idx}'
        principled = {
            'type': 'principled',
            'base_color': tex['albedo'],
            'metallic': tex['metallic'],
            'roughness': tex['roughness'],
        }
        scene_dict[obj_key] = {
            'type': 'obj',
            'filename': obj_file,
            'to_world': shared_transform,
            'bsdf': {
                'type': 'twosided',
                'nested': {
                    'type': 'normalmap',
                    'normalmap': tex['normal'],
                    'nested': principled,
                },
            },
        }
        opt_param_keys += [
            f"{obj_key}.bsdf.brdf_0.normalmap.data",
            f"{obj_key}.bsdf.brdf_0.nested_bsdf.base_color.data",
            f"{obj_key}.bsdf.brdf_0.nested_bsdf.metallic.data",
            f"{obj_key}.bsdf.brdf_0.nested_bsdf.roughness.data",
        ]

    return scene_dict, opt_param_keys


def obj_index_from_path(obj_file):
    """Return the object index encoded in a mesh file name.

    The index is the first run of decimal digits in the *basename* of
    ``obj_file``; directory components are ignored. For example,
    ``'/data/obj_12.obj'`` gives 12.

    Raises:
        ValueError: If the basename contains no digits.
    """
    match = re.search(r'\d+', os.path.basename(obj_file))
    if match is None:
        raise ValueError(
            f"Cannot derive an object index from mesh file name {obj_file!r}: "
            "its basename contains no digits"
        )
    return int(match.group(0))


def create_scene_drjit(obj_file_path, obj_trans, obj_scale, obj_rotate, camera_dict, resolution=(512, 512)):
    """Load a Mitsuba scene whose objects all start from the initial PBR textures.

    Meshes come from ``get_obj_files_list(obj_file_path)``. Each object's index
    is parsed from its file name with ``obj_index_from_path``. Every object
    starts from the albedo/metallic/roughness/normal PNGs under
    ``MATERIAL_INIT_PATH`` and is lit by ``HDR_ROOT_DIR/uniform.exr``.

    Args:
        obj_file_path: Passed to ``get_obj_files_list``.
        obj_trans, obj_scale, obj_rotate: Placement shared by all objects.
        camera_dict: Camera description (see ``create_scene_dict_drjit``).
        resolution: ``(width, height)`` of the film.

    Returns:
        tuple: ``(scene, opt_param_keys, obj_files, obj_idx_list)``.

    Raises:
        ValueError: If a mesh file's basename contains no digits.
    """
    hdr_path = f"{HDR_ROOT_DIR}/uniform.exr"
    obj_files = get_obj_files_list(obj_file_path)
    obj_idx_list = [obj_index_from_path(obj_file) for obj_file in obj_files]
    material_path = {
        'albedo': f'{MATERIAL_INIT_PATH}/albedo.png',
        'metallic': f'{MATERIAL_INIT_PATH}/metallic.png',
        'roughness': f'{MATERIAL_INIT_PATH}/roughness.png',
        'normal': f'{MATERIAL_INIT_PATH}/normal.png'
    }

    scene_dict, opt_param_keys = create_scene_dict_drjit(obj_files, obj_idx_list, camera_dict, material_path,
                                                         hdr_path, resolution, obj_trans, obj_scale, obj_rotate)

    # Only instantiates the scene; rendering happens elsewhere.
    scene = mi.load_dict(scene_dict)

    return scene, opt_param_keys, obj_files, obj_idx_list


### Rendering Utils
@dr.wrap(source='torch', target='drjit')
def render_scene(scene, params, material_dict, emitter_data=None, sensor=0,
                 spp=32, spp_grad=32, seed=0, seed_grad=0):
    """Write lighting and material maps into ``params``, then render ``scene``.

    Wrapped with ``dr.wrap(source='torch', target='drjit')`` so torch tensors can
    be passed in and gradients can flow back to torch.

    Args:
        scene: Mitsuba scene.
        params: Scene parameters of ``scene``; modified in place.
        material_dict: Maps object index -> material dict with a 'type' key.
            For 'uniform' materials, 'albedo', 'metallic' and 'roughness' are
            written to the BSDF's ``.value`` parameters. Any other type is
            treated as textured, and 'normal', 'albedo', 'metallic' and
            'roughness' are written to the ``.data`` parameters.
        emitter_data: Optional environment-map data written to 'emitter.data'.
        sensor, spp, spp_grad, seed, seed_grad: Forwarded to ``mi.render``.

    Returns:
        The image returned by ``mi.render``.
    """
    if emitter_data is not None:
        if isinstance(emitter_data, torch.Tensor):
            emitter_data = mi.TensorXf(emitter_data)
        params['emitter.data'] = emitter_data

    for obj_idx, mat in material_dict.items():
        # Albedo is raised to 2.2 (approximate sRGB -> linear) before reaching the BSDF.
        if mat['type'] == 'uniform':
            base_color = mi.TensorXf(mat['albedo']) ** 2.2
            metallic = mi.TensorXf(mat['metallic'])
            roughness = mi.TensorXf(mat['roughness'])
            params[f"obj_{obj_idx}.bsdf.brdf_0.base_color.value"] = base_color
            params[f"obj_{obj_idx}.bsdf.brdf_0.metallic.value"] = metallic
            params[f"obj_{obj_idx}.bsdf.brdf_0.roughness.value"] = roughness
        else:  # procedural / textured material
            normal_data = mi.TensorXf(mat['normal'])
            albedo_data = mi.TensorXf(mat['albedo']) ** 2.2
            metallic_data = mi.TensorXf(mat['metallic'])
            roughness_data = mi.TensorXf(mat['roughness'])
            params[f"obj_{obj_idx}.bsdf.brdf_0.normalmap.data"] = normal_data
            params[f"obj_{obj_idx}.bsdf.brdf_0.nested_bsdf.base_color.data"] = albedo_data
            params[f"obj_{obj_idx}.bsdf.brdf_0.nested_bsdf.metallic.data"] = metallic_data
            params[f"obj_{obj_idx}.bsdf.brdf_0.nested_bsdf.roughness.data"] = roughness_data

    # Propagate all pending parameter changes to the scene once before rendering.
    params.update()
    img = mi.render(scene, params=params, sensor=sensor, spp=spp, spp_grad=spp_grad, seed=seed, seed_grad=seed_grad)

    return img


def render_scene_drjit(scene, params, sensor=0,
                 spp=32, spp_grad=32, seed=0, seed_grad=0):
    """Render ``scene`` with ``params`` exactly as they currently stand.

    No parameters are written or updated here. This is a thin wrapper that
    forwards every argument to ``mi.render`` and returns its image.
    """
    return mi.render(
        scene,
        params=params,
        sensor=sensor,
        spp=spp,
        spp_grad=spp_grad,
        seed=seed,
        seed_grad=seed_grad,
    )


@dr.wrap(source='torch', target='drjit')
def relighting_scene(scene, params, materials, obj_idx_list, emitter_data, spp=32, albedo_gamma=2.2):
    """
    Render a scene with the given materials and a specific lighting setup.

    Args:
        scene: Mitsuba scene whose objects use textured BSDFs (``.data`` parameters).
        params: Scene parameters; modified in place.
        materials: Maps object index -> dict with 'normal', 'albedo', 'metallic'
            and 'roughness' maps (e.g. the output of ``get_materials``).
        obj_idx_list: Object indices whose materials are written.
        emitter_data: Environment map data written to 'emitter.data', or None
            to keep the current lighting.
        spp: Samples per pixel.
        albedo_gamma: Exponent applied to the albedo before it is written to the
            base-colour texture. The default 2.2 matches ``render_scene``, which
            linearises the same albedo data with ``** 2.2`` before writing it to
            the same parameter key. Without it, relit images use a brighter
            albedo than the one that was optimised. Pass 1.0 to write the
            albedo unchanged (the previous behaviour).

    Returns:
        Rendered image
    """
    if emitter_data is not None:
        if isinstance(emitter_data, torch.Tensor):
            emitter_data = mi.TensorXf(emitter_data)
        params['emitter.data'] = emitter_data

    for obj_idx in obj_idx_list:
        mat = materials[obj_idx]
        normal_data = mi.TensorXf(mat['normal'])
        albedo_data = mi.TensorXf(mat['albedo'])
        if albedo_gamma != 1.0:
            # Approximate sRGB -> linear, consistent with render_scene.
            albedo_data = albedo_data ** albedo_gamma
        metallic_data = mi.TensorXf(mat['metallic'])
        roughness_data = mi.TensorXf(mat['roughness'])

        params[f"obj_{obj_idx}.bsdf.brdf_0.normalmap.data"] = normal_data
        params[f"obj_{obj_idx}.bsdf.brdf_0.nested_bsdf.base_color.data"] = albedo_data
        params[f"obj_{obj_idx}.bsdf.brdf_0.nested_bsdf.metallic.data"] = metallic_data
        params[f"obj_{obj_idx}.bsdf.brdf_0.nested_bsdf.roughness.data"] = roughness_data

    params.update()
    return mi.render(scene, params=params, spp=spp)


def relighting_scene_drjit(scene, params, emitter_data, spp=32):
    """
    Render a scene under new lighting, keeping the materials already in ``params``.

    Only 'emitter.data' is written; no material parameters are touched.

    Args:
        scene: Mitsuba scene
        params: Scene parameters; modified in place.
        emitter_data: Environment map data (a torch tensor is converted to
            ``mi.TensorXf``), or None to keep the current lighting.
        spp: Samples per pixel

    Returns:
        Rendered image
    """
    if emitter_data is not None:
        new_env = mi.TensorXf(emitter_data) if isinstance(emitter_data, torch.Tensor) else emitter_data
        params['emitter.data'] = new_env

    params.update()
    return mi.render(scene, params=params, spp=spp)


def get_materials(material_dict, obj_idx_list):
    """
    Evaluate the material of each requested object and keep only its PBR maps.

    ``evaluate()`` is called once per entry of ``obj_idx_list`` (in order). Only
    the 'normal', 'albedo', 'metallic' and 'roughness' entries of its result are
    kept; any other keys are dropped.

    Args:
        material_dict: Maps object index -> material object exposing ``evaluate()``.
        obj_idx_list: Object indices to evaluate.

    Returns:
        dict: Object index -> ``{'normal', 'albedo', 'metallic', 'roughness'}`` dict.
    """
    kept_keys = ('normal', 'albedo', 'metallic', 'roughness')

    def _select(evaluated):
        return {key: evaluated[key] for key in kept_keys}

    return {idx: _select(material_dict[idx].evaluate()) for idx in obj_idx_list}


def render_scene_torch(scene, params, spp=512):
    """
    Render ``scene`` and return a gamma-encoded image as a torch tensor.

    The ``mi.render`` output is converted to a torch tensor and permuted from
    ``(H, W, C)`` to ``(1, C, H, W)``. It is then clamped to [0, 1] and raised to
    ``1 / 2.2``. That last step is an approximate linear -> sRGB gamma
    *encoding*, not a conversion to linear space.

    Args:
        scene: Mitsuba scene
        params: Scene parameters
        spp: Samples per pixel

    Returns:
        torch.Tensor: Display-ready image tensor of shape ``(1, C, H, W)``.
    """
    rendered = np.array(mi.render(scene, params=params, spp=spp))
    nchw = torch.from_numpy(rendered).permute(2, 0, 1).unsqueeze(0)
    clamped = nchw.clamp(0.0, 1.0)
    return torch.pow(clamped, 1 / 2.2)




