import os
import platform
if platform.system() == 'Linux':
    # Default PyOpenGL to the OSMesa software backend on Linux (presumably for
    # headless hosts without a display). This has to happen before pyrender /
    # OpenGL are imported below. setdefault (rather than assignment) lets the
    # caller pick a different backend, e.g. PYOPENGL_PLATFORM=egl.
    os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")
import numpy as np
import trimesh
import pyrender
import argparse
from PIL import Image
from .util import *
import io

def look_at(point, pos, up):
    """Build a 4x4 pose matrix for an eye at ``pos`` looking toward ``point``.

    The third (Z) column points from ``point`` back toward ``pos``. A camera
    that looks down its local -Z axis, as pyrender cameras do, therefore faces
    the target.

    Args:
        point: 3-vector the eye looks at.
        pos: 3-vector eye position; stored as the translation column.
        up: 3-vector approximate "up" direction. It must not be parallel to
            ``pos - point``; otherwise the cross product is zero and the
            normalisation yields NaNs.

    Returns:
        A (4, 4) float ndarray whose first three columns are the unit x, y and z
        axes, and whose last column holds ``pos``.
    """
    eye = np.asarray(pos)
    backward = eye - np.asarray(point)
    right = np.cross(up, backward)
    true_up = np.cross(backward, right)

    pose = np.eye(4)
    # Normalise each axis only when writing it into the matrix; the cross
    # products above intentionally use the unnormalised vectors.
    for column, axis in enumerate((right, true_up, backward)):
        pose[:3, column] = axis / np.linalg.norm(axis)
    pose[:3, 3] = pos
    return pose

def wireframe_cube(color=None):
    if color is None:
        color = [0.78,0.78,0.78,1]
    elif len(color) == 3:
        color = color + [1]
    primitives = []
    for i in [0.,1.]:
        for j in [0.,1.]:
            for k in [0.,1.]:
                corner = [i,j,k]
                for idx in [0,1,2]:
                    other_corner = corner.copy()
                    other_corner[idx] = 1.0 - corner[idx]
                    primitives.append(pyrender.Primitive(positions=[corner, other_corner],mode=1, color_0=[color,color]))

    mesh = pyrender.Mesh(primitives = primitives)

    return mesh


def render_metamaterial(obj_path, r = None, color = None, box_color = None):

    if color is None:
        color = [.5,.5,.5] # [1.,1.,0.]
    if box_color is None:
        box_color = [1.0, 0.0, 0.0] # None

    res_factor = 1
    built_renderer = False
    if r is None:
        r = pyrender.OffscreenRenderer(512 * res_factor, 512 * res_factor)
        built_renderer = True
    
    with open_file(obj_path) as f:
        fuze_trimesh = trimesh.load(f, 'obj')

    fuze_trimesh.visual.vertex_colors = color

    mesh = pyrender.Mesh.from_trimesh(fuze_trimesh, smooth=True,  wireframe=False)
    scene = pyrender.Scene(ambient_light=[0,0,0],bg_color=[1.,1.,1.])

    camera = pyrender.OrthographicCamera(xmag=.7, ymag=.7)

    camera_locations = {
        'front': (np.array([0.5, 0.5, 2.0]),np.array([0.0, 1.0, 0.0])),
        'right': (np.array([2.0, 0.5, 0.5]),np.array([0.0, 1.0, 0.0])),
        'top': (np.array([0.5, 2.0, 0.5]),np.array([0.0, 0.0, -1.0])),
        'top_right': (np.array([.85, 1.0 ,2.0]), np.array([0.0, 1.0, 0.0]))
    }

    camera_pose = look_at(np.array([0.5, 0.5, 0.5]), camera_locations['top'][0], camera_locations['top'][1])
    camera_node = scene.add(camera, pose=camera_pose)

    light = pyrender.DirectionalLight([1,1,1], 1200)
    light_node = scene.add(light, pose=camera_pose)
    scene.set_pose(light_node, look_at([0,0,0], [10, 10, 10], [0,1,0]))
    
    renders = {}
    cube_mesh = wireframe_cube(box_color)
    scene.add(cube_mesh)
    scene.add(mesh)
    for location_name, camera_location in camera_locations.items():

        # Set up camera and lighting
        camera_pose = look_at(np.array([0.5, 0.5, 0.5]), camera_location[0], camera_location[1])
        scene.set_pose(camera_node, pose=camera_pose)
        scene.set_pose(light_node, pose=camera_pose)
        color, _ = r.render(scene)

        renders[location_name] = color
    top_row = np.concatenate([renders['top'],renders['top_right']],axis=1)
    bottom_row = np.concatenate([renders['front'],renders['right']],axis=1)
    combined = np.concatenate([top_row,bottom_row],axis=0)
    renders['all'] = combined

    if built_renderer:
         r.delete()
    return renders

def run_render(input, output, color = None, box_color = None):
    renders = render_metamaterial(input, color = color, box_color = box_color)
    if output.startswith('s3://'):
        s3path = output[len('s3://'):]
        s3parts = s3path.split('/')
        bucket = s3parts[0]
        key = '/'.join(s3parts[1:])
        key = key if key.endswith('/') else key + '/'
        for viewpoint, render in renders.items():
            image = Image.fromarray(render)
            key + f'{viewpoint}.png'
            buffer = io.BytesIO()
            image.save(buffer, format='png')
            write_s3(bucket, key + viewpoint + '.png', buffer.getvalue())
            buffer.close()
    else:
        os.makedirs(output, exist_ok=True)
        for viewpoint, render in renders.items():
            image = Image.fromarray(render)
            image.save(os.path.join(output, f'{viewpoint}.png'))

def main(argv=None):
    """Command-line entry point: render the ``-i`` OBJ and write images to ``-o``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default).
    """
    parser = argparse.ArgumentParser(
        description='Render an OBJ mesh inside a unit-cube wireframe from four viewpoints.'
    )
    # Both are required: without them run_render would fail later with an
    # unhelpful AttributeError on None.
    parser.add_argument('-i', type=str, required=True, help='input OBJ path')
    parser.add_argument(
        '-o', type=str, required=True,
        help='output directory, or an s3://bucket/prefix URI',
    )
    args = parser.parse_args(argv)
    run_render(args.i, args.o)

# Script entry point. Because of the relative ``from .util import *`` import,
# run this file as a module (``python -m <package>.<module>``), not as a path.
if __name__ == '__main__':
    main()