import os
import mitsuba as mi
import numpy as np
# Choose the Mitsuba variant at import time from the RENDER_W_GPU environment
# variable: the exact value '1' selects the CUDA variant ('cuda_ad_rgb'); any
# other value, or leaving it unset, selects the scalar CPU variant
# ('scalar_rgb'). CPU is the default because a GPU may be unavailable or
# already busy with other work.
if os.environ.get('RENDER_W_GPU') != '1':
    print("Rendering with CPU")
    mi.set_variant('scalar_rgb')
else:
    print("Rendering with GPU")
    mi.set_variant('cuda_ad_rgb')
from mitsuba import ScalarTransform4f as T

def mitsuba_render(xml_path, save_dir,
                   render_img_name="render",
                   img_type="png",
                   sensor_width=512,
                   sensor_height=512,
                   sensor_sep=25,
                   phi=-90,
                   radius=13.2,
                   theta=15,
                   spp=512,
                   fov=39.3077):
    """Load a Mitsuba scene file and render it from a camera aimed at the origin.

    The camera position is the point (0, 0, radius) transformed by two
    rotations: ``phi`` about the world Z axis and ``theta`` about the Y axis.
    Both angles are passed straight to ``T.rotate`` (degrees). The camera
    always looks at the origin with +Z as its up vector. With ``theta == 0``
    the camera sits on the Z axis, parallel to that up vector, which is
    presumably a degenerate look-at; avoid it.

    Args:
        xml_path: Path of the Mitsuba scene XML to load.
        save_dir: Directory to write the image into, or ``None`` to skip
            writing. A non-empty directory is created if it does not exist;
            ``""`` means the current working directory.
        render_img_name: Output file name without extension.
        img_type: Output file extension (e.g. ``"png"`` or ``"exr"``),
            forwarded as part of the file name to ``mi.util.write_bitmap``.
        sensor_width: Film width in pixels.
        sensor_height: Film height in pixels.
        sensor_sep: Currently unused; kept for backward compatibility.
        phi: Rotation about the Z axis, in degrees.
        radius: Distance from the origin to the camera.
        theta: Rotation about the Y axis, in degrees.
        spp: Samples per pixel passed to ``mi.render``.
        fov: Field of view of the perspective sensor (the default preserves
            the previously hard-coded value).

    Returns:
        The image returned by ``mi.render``.
    """
    scene = mi.load_file(xml_path)

    def load_sensor(r, phi, theta):
        """Build a perspective sensor at distance *r* that looks at the origin."""
        # Apply two rotations to convert from spherical coordinates to world 3D coordinates.
        origin = T.rotate([0, 0, 1], phi).rotate([0, 1, 0], theta) @ mi.ScalarPoint3f([0, 0, r])

        return mi.load_dict({
            'type': 'perspective',
            'fov': fov,
            'to_world': T.look_at(
                origin=origin,
                target=[0, 0, 0],
                up=[0, 0, 1]
            ),
            # mi.render() below receives spp explicitly, which presumably
            # overrides this sampler's sample_count.
            'sampler': {
                'type': 'independent',
                'sample_count': 16
            },
            'film': {
                'type': 'hdrfilm',
                'width': sensor_width,
                'height': sensor_height,
                'rfilter': {
                    'type': 'tent',
                },
                'pixel_format': 'rgb',
            },
        })

    sensor = load_sensor(radius, phi, theta)
    image = mi.render(scene, spp=spp, sensor=sensor)

    # The file format follows img_type (PNG by default), not always EXR.
    if save_dir is not None:
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # os.path.join avoids writing to "/<name>" when save_dir is "".
        out_path = os.path.join(save_dir, f"{render_img_name}.{img_type}")
        mi.util.write_bitmap(out_path, image)

    return image
    
def get_obj_files(scene_folder):
    """Return the names of entries in *scene_folder* ending in ``.obj``.

    Only bare names are returned (not joined with *scene_folder*), in the
    order ``os.listdir`` yields them. The suffix match is case-sensitive.
    """
    entries = os.listdir(scene_folder)
    return list(filter(lambda name: name.endswith('.obj'), entries))

# scene_folder = args.scene_folder
# # we find all the obj and mtl files in the scene folder, and write xml render file
# obj_files = [f for f in os.listdir(scene_folder) if f.endswith('.obj')]

def mi_write_img(save_path, img):
    """Save *img* to *save_path* via ``mi.util.write_bitmap``.

    Thin convenience wrapper; returns ``None``.
    """
    writer = mi.util.write_bitmap
    writer(save_path, img)


# Opening part of the generated Mitsuba scene XML. save_render_xml writes it
# before the per-mesh fragments and xml_tail. It declares:
#   - a "path" integrator with maxDepth -1 (presumably "no depth limit"),
#   - a default perspective sensor. mitsuba_render passes its own sensor to
#     mi.render, so this one is not the camera used by those renders,
#   - a shared roughplastic BSDF with id "surfaceMaterial", referenced by the
#     floor rectangle in xml_tail.
# The scene declares version 0.6.0 and uses that format's camelCase parameter
# names (maxDepth, sampleCount, intIOR, diffuseReflectance, ...).
xml_head = \
"""
<scene version="0.6.0">
    <integrator type="path">
        <integer name="maxDepth" value="-1"/>
    </integrator>
    <sensor type="perspective">
        <float name="farClip" value="100"/>
        <float name="nearClip" value="0.1"/>
        <transform name="toWorld">
            <lookat origin="3,3,3" target="0,0,0" up="0,0,1"/>
        </transform>
        <float name="fov" value="25"/>
        
        <sampler type="ldsampler">
            <integer name="sampleCount" value="256"/>
        </sampler>
        <film type="hdrfilm">
            <integer name="width" value="1600"/>
            <integer name="height" value="1200"/>
            <rfilter type="gaussian"/>
            <boolean name="banner" value="false"/>
        </film>
    </sensor>
    
    <bsdf type="roughplastic" id="surfaceMaterial">
        <string name="distribution" value="ggx"/>
        <float name="alpha" value="0.05"/>
        <float name="intIOR" value="1.46"/>
        <rgb name="diffuseReflectance" value="1,1,1"/> <!-- default 0.5 -->
    </bsdf>
    
"""


# Closing part of the generated scene XML. save_render_xml writes it after the
# per-mesh fragments. It adds:
#   - a floor rectangle using the "surfaceMaterial" BSDF from xml_head,
#     scaled by 10 in x and y and translated to z = -0.5,
#   - a rectangular area emitter (rgb radiance 6,6,6), scaled by 10 in x and
#     y and placed with a lookat from (-4, 4, 20) toward the origin,
# and then closes the </scene> element.
xml_tail = \
"""
    <shape type="rectangle">
        <ref name="bsdf" id="surfaceMaterial"/>
        <transform name="toWorld">
            <scale x="10" y="10" z="1"/>
            <translate x="0" y="0" z="-0.5"/>
        </transform>
    </shape>
    
    <shape type="rectangle">
        <transform name="toWorld">
            <scale x="10" y="10" z="1"/>
            <lookat origin="-4,4,20" target="0,0,0" up="0,0,1"/>
        </transform>
        <emitter type="area">
            <rgb name="radiance" value="6,6,6"/>
        </emitter>
    </shape>
</scene>
"""

def load_one_mesh(obj_path, mtl_path=None, texture_path=None):
    """Return a Mitsuba XML fragment that loads a single ``.obj`` mesh.

    The mesh id is the part of the file *name* after its last underscore,
    minus the extension (``"scene/mesh_3.obj"`` -> ``"3"``). It is used to
    build unique XML ids for the texture and BSDF.

    Args:
        obj_path: Path of the ``.obj`` file.
        mtl_path: Accepted for interface compatibility; currently unused.
        texture_path: Optional image path. When truthy, the mesh gets a
            diffuse BSDF whose reflectance is this bitmap. Otherwise it gets a
            uniform grey (0.5) diffuse BSDF.

    Returns:
        An XML string with the optional texture/BSDF declarations followed
        by one ``<shape type="obj">`` element rotated 90 degrees about X.
        All interpolated paths and ids are XML-escaped, so file names
        containing ``&``, ``<``, ``>`` or ``"`` still produce valid XML.
    """
    from xml.sax.saxutils import escape

    def _attr(value):
        # Escape &, <, > and " for use inside a double-quoted XML attribute.
        return escape(str(value), {'"': '&quot;'})

    # Derive the id from the file name only; splitting the whole path would
    # pick up underscores in directory names and yield ids containing '/'.
    stem = os.path.splitext(os.path.basename(obj_path))[0]
    obj_id = _attr(stem.split('_')[-1])
    obj_file = _attr(obj_path)

    # mtl_path is intentionally ignored for now: only the texture is used.

    if texture_path:
        tex_file = _attr(texture_path)
        mesh_str = f"""
        <texture type="bitmap" id="{obj_id}_image">
            <string name="filename" value="{tex_file}"/>
        </texture>

        <bsdf type="diffuse" id="{obj_id}_material">
            <!-- Reference the texture named my_image and pass it
                to the BSDF as the reflectance parameter -->
            <ref name="reflectance" id="{obj_id}_image"/>
        </bsdf>

        <shape type="obj">
            <string name="filename" value="{obj_file}"/>

            <!-- Reference the material named my_material -->
            <ref id="{obj_id}_material"/>
        """
    else:
        mesh_str = f"""
        <shape type="obj">
            <string name="filename" value="{obj_file}"/>
        <bsdf type="diffuse">
            <rgb name="reflectance" value="0.5,0.5,0.5"/>
        </bsdf>
        """

    # Shared tail: the transform and the closing </shape> tag. The rotation
    # about X presumably converts a Y-up mesh into this Z-up scene.
    mesh_str += """
    <transform name="toWorld">
        <rotate x="1" angle="90"/> <!-- Rotates the object 90 degrees around the X-axis -->
        <translate x="0" y="0" z="0"/>
    </transform>
    </shape>
    """
    return mesh_str

def load_mesh_list(obj_list, scene_folder):
    """Concatenate the XML fragments for every mesh path in *obj_list*.

    Meshes are emitted in sorted path order. The caller's list is left
    unmodified; previously it was sorted in place.

    For a mesh file named ``<prefix>_<id>.obj``, ``material_<id>.mtl`` is
    looked up in *scene_folder*. If it exists, the mesh is textured with
    ``material_<id>.png`` from the same folder, but only when that image
    exists too. Otherwise the mesh gets load_one_mesh's default grey
    material.

    Returns:
        The XML string for all meshes, in order.
    """
    fragments = []
    for obj_path in sorted(obj_list):
        # Use the file name only, so underscores in directory names do not
        # leak into the id.
        stem = os.path.splitext(os.path.basename(obj_path))[0]
        obj_id = stem.split('_')[-1]
        mtl_path = os.path.join(scene_folder, 'material_' + obj_id + '.mtl')
        texture_path = os.path.join(scene_folder, 'material_' + obj_id + '.png')
        if os.path.exists(mtl_path):
            print('{} exists'.format(obj_id))
            # Only reference the texture when the image is really there;
            # a dangling bitmap path would presumably make the scene fail to load.
            if not os.path.exists(texture_path):
                texture_path = None
            fragments.append(load_one_mesh(obj_path, mtl_path, texture_path))
        else:
            fragments.append(load_one_mesh(obj_path))
    return "".join(fragments)


def save_render_xml(obj_path_list, save_path, scene_folder):
    """Write a complete Mitsuba scene XML for the given meshes to *save_path*.

    The file content is ``xml_head``, then the per-mesh fragments from
    ``load_mesh_list(obj_path_list, scene_folder)``, then ``xml_tail``.
    The fragments are built before the file is opened, so a failure while
    building them leaves any existing file at *save_path* untouched.
    """
    obj_load = load_mesh_list(obj_path_list, scene_folder)
    # The XML has no encoding declaration, so readers assume UTF-8. Write
    # UTF-8 explicitly instead of the platform's locale encoding, which could
    # mangle non-ASCII mesh/texture paths.
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(xml_head + obj_load + xml_tail)


# save_render_xml([os.path.join(scene_folder, obj_file) for obj_file in obj_files], os.path.join(scene_folder, "render.xml"), scene_folder)
# mitsuba_render(os.path.join(scene_folder, "render.xml"), scene_folder)
# # remove xml
# os.remove(os.path.join(scene_folder, "render.xml"))

