import numpy as np
import os
import open3d as o3d
from PIL import Image

def normalize_mesh_and_cloud(mesh, cloud):
    """Rescale ``cloud`` to the mesh's size and centre both at the origin.

    The cloud is scaled about its centroid (``cloud.get_center()``) so that its
    axis-aligned bounding-box diagonal matches the mesh's. Both geometries are
    then translated so that their axis-aligned bounding-box centres sit at the
    origin. Both objects are modified in place and also returned.

    If the cloud's bounding box has zero diagonal (e.g. a single point), the
    scaling step is skipped rather than dividing by zero, which would otherwise
    fill the cloud with inf/nan coordinates.

    Returns
    -------
    (mesh, cloud)
        The same objects that were passed in.
    """
    mesh_scale = np.linalg.norm(mesh.get_max_bound() - mesh.get_min_bound())
    cloud_scale = np.linalg.norm(cloud.get_max_bound() - cloud.get_min_bound())
    # Guard against a degenerate cloud: numpy would yield inf for x / 0.0.
    if cloud_scale > 0:
        cloud.scale(mesh_scale / cloud_scale, center=cloud.get_center())

    # Translate by the bounding-box centre (not the centroid) of each geometry.
    cloud_center = (cloud.get_min_bound() + cloud.get_max_bound()) / 2
    cloud.points = o3d.utility.Vector3dVector(np.asarray(cloud.points) - cloud_center)

    mesh_center = (mesh.get_min_bound() + mesh.get_max_bound()) / 2
    mesh.vertices = o3d.utility.Vector3dVector(np.asarray(mesh.vertices) - mesh_center)
    return mesh, cloud

def render_transparent_mesh_with_points(mesh, pcd, left, right, bottom, top,
                                        output_path="figures/test_visual.png",
                                        width=800, height=600,
                                        point_size=5.0,
                                        mesh_color=(0.8, 0.8, 0.8),
                                        mesh_alpha=0.3,
                                        ortho_scale=2.0,
                                        up=(0, 0, 1),
                                        camera_offset=(1, 1, 1),
                                        skip_pcd=False):
    """Render a (semi-)transparent mesh, optionally with a point cloud, to an image file.

    The scene is rendered off-screen on a white background with an
    orthographic camera. The camera looks at ``mesh.get_center()`` from
    ``mesh.get_center() + camera_offset``. Pixels whose depth value equals 1
    are treated as background and forced to pure white.

    Example::

        render_transparent_mesh_with_points(mesh, cloud, -1.1, 1, -0.5, 1,
                                            camera_offset=[-250, 0, 0],
                                            skip_pcd=True, up=[0, 1, 0],
                                            mesh_alpha=1)

    Parameters
    ----------
    mesh : o3d.geometry.TriangleMesh
        Mesh to draw with a lit, alpha-blended material.
    pcd : o3d.geometry.PointCloud
        Point cloud drawn with an unlit material. Ignored when ``skip_pcd``.
    left, right, bottom, top : float
        Orthographic frustum extents passed to ``Camera.set_projection``.
    output_path : str
        Destination file. Its parent directory is created if missing.
    width, height : int
        Output resolution in pixels.
    point_size : float
        Rendered point size for the cloud.
    mesh_color : sequence of 3 floats
        RGB mesh colour. Any sequence or ndarray is accepted.
    mesh_alpha : float
        Mesh opacity, appended to ``mesh_color`` as the fourth channel.
    ortho_scale : float
        Currently unused; kept for backward compatibility.
    up : sequence of 3 floats
        Camera up vector.
    camera_offset : sequence of 3 floats
        Eye position relative to the mesh centre.
    skip_pcd : bool
        If True, only the mesh is rendered.
    """
    renderer = o3d.visualization.rendering.OffscreenRenderer(width, height)

    # Plain white background with no skybox.
    renderer.scene.set_background([1, 1, 1, 1])
    renderer.scene.show_skybox(False)

    # list() makes tuples/ndarrays work too: `ndarray + [alpha]` would
    # broadcast-add alpha to each channel instead of appending it.
    mesh_mat = o3d.visualization.rendering.MaterialRecord()
    mesh_mat.shader = "defaultLitTransparency"
    mesh_mat.base_color = list(mesh_color) + [mesh_alpha]
    renderer.scene.add_geometry("mesh", mesh, mesh_mat)

    if not skip_pcd:
        pcd_mat = o3d.visualization.rendering.MaterialRecord()
        pcd_mat.shader = "defaultUnlit"
        pcd_mat.point_size = point_size
        renderer.scene.add_geometry("points", pcd, pcd_mat)

    # Camera: look at the mesh centre from a fixed offset.
    center = mesh.get_center()
    eye = center + np.asarray(camera_offset, dtype=float)
    camera = renderer.scene.camera
    camera.look_at(center, eye, np.asarray(up, dtype=float))

    near = 0.01
    far = 1000
    camera.set_projection(
        projection_type=o3d.visualization.rendering.Camera.Projection.Ortho,
        left=left,
        right=right,
        bottom=bottom,
        top=top,
        near=near,
        far=far
    )

    renderer.scene.scene.enable_sun_light(True)

    img = np.asarray(renderer.render_to_image())
    depth = np.asarray(renderer.render_to_depth_image())
    # Depth == 1 is treated as "nothing rendered here": paint it pure white.
    img[depth == 1] = 255

    # The default output path points into a sub-directory that may not exist.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # An (H, W, 3) uint8 array is inferred as RGB. The `mode` argument is
    # deprecated in recent Pillow releases.
    Image.fromarray(img).save(output_path)

    print(f"Saved transparent rendered image to {output_path}")

def get_color_strengths(cloud):
    """Return the first colour channel of ``cloud.colors`` copied into three columns.

    Row ``i`` of the result is ``[c, c, c]``, where ``c`` is column 0 of point
    ``i``'s colour (red, for Open3D RGB colours). The result can be used as a
    per-channel blend weight.
    """
    # Slicing with :1 keeps a column axis, shape (N, 1). This assumes colors is 2-D.
    first_channel = np.asarray(cloud.colors)[:, :1]
    return np.tile(first_channel, (1, 3))

def recolor_cloud(cloud, strengths, color=[1, 0, 1], base_color=[0.5, 0.5, 0.5], magnify_strengths=1):
    """Blend each point's colour between ``base_color`` and ``color``.

    The per-point weight is ``(2 * strengths) ** magnify_strengths`` and the new
    colour is ``weight * color + (1 - weight) * base_color``. Weights above 1
    (strengths above 0.5 when the exponent is positive) extrapolate past
    ``color``, so components can fall outside [0, 1]. No clipping is applied.

    ``cloud.colors`` is replaced and the same cloud object is returned.
    """
    n_points = np.asarray(cloud.colors).shape[0]
    # One row per point, ready for element-wise blending with `strengths`.
    tint = np.tile(np.asarray(color), (n_points, 1))
    background = np.tile(np.asarray(base_color), (n_points, 1))

    weights = (2 * strengths) ** magnify_strengths

    blended = tint * weights + background * (1 - weights)
    cloud.colors = o3d.utility.Vector3dVector(blended)
    return cloud

def prep_mesh_and_cloud(mesh_path, cloud_path, reflection=(1, 1, 1), axes=(0, 1, 2), color=(1, 0, 0),
                        base_color=(0.6, 0.6, 0.6), magnify_strengths=1, mesh_color1=(0, 1, 0),
                        mesh_color2=(0, 0, 1), axis='z', mesh_grad=True):
    """Load a mesh and a point cloud, align them, and colour both for rendering.

    Steps:
    1. Read both geometries with ``o3d.io`` and normalise them with
       ``normalize_mesh_and_cloud``.
    2. Reorder the mesh's vertex coordinates by ``axes``, then multiply each
       coordinate by the matching entry of ``reflection``. The point cloud is
       not permuted or reflected.
    3. Recompute mesh vertex normals.
    4. If ``mesh_grad``, colour the mesh with a gradient along ``axis`` from
       ``mesh_color1`` to ``mesh_color2``.
    5. Recolour the cloud from its first colour channel via ``recolor_cloud``
       using ``color``, ``base_color`` and ``magnify_strengths``.

    Returns
    -------
    (mesh, cloud)
    """
    cloud = o3d.io.read_point_cloud(cloud_path)
    mesh = o3d.io.read_triangle_mesh(mesh_path)
    mesh, cloud = normalize_mesh_and_cloud(mesh, cloud)

    # NOTE(review): an odd number of negative `reflection` entries (or an odd
    # permutation in `axes`) reverses triangle winding, so the normals computed
    # below may point inward. Confirm whether triangles should be flipped.
    mesh_verts = np.asarray(mesh.vertices)[:, list(axes)]
    mesh.vertices = o3d.utility.Vector3dVector(mesh_verts * np.asarray(reflection, dtype=float))

    mesh.compute_vertex_normals()
    if mesh_grad:
        # Previously these arguments were accepted but ignored.
        mesh = color_mesh_gradient(mesh, axis=axis, color_start=mesh_color1, color_end=mesh_color2)

    strs = get_color_strengths(cloud)
    cloud = recolor_cloud(cloud, strs, color=color, base_color=base_color, magnify_strengths=magnify_strengths)
    return mesh, cloud

def color_mesh_gradient(mesh, axis="z", color_start=(0, 1, 0), color_end=(0, 0, 1)):
    """
    Colors mesh vertices with a linear gradient along the given axis.

    Each vertex coordinate along ``axis`` is rescaled to [0, 1] using the
    mesh's min/max along that axis. The result then linearly interpolates
    between ``color_start`` and ``color_end``. A mesh that is flat along
    ``axis`` gets ``color_start`` everywhere. An empty mesh gets an empty
    colour array instead of failing on ``min()``/``max()``.

    Parameters
    ----------
    mesh : o3d.geometry.TriangleMesh
        The input mesh to color. It will be modified in-place.
    axis : str, optional
        Axis to apply gradient along: "x", "y", or "z".
    color_start : tuple
        RGB tuple for the minimum axis value (each in [0,1]).
    color_end : tuple
        RGB tuple for the maximum axis value (each in [0,1]).

    Returns
    -------
    mesh : o3d.geometry.TriangleMesh
        The mesh with vertex colors assigned.

    Raises
    ------
    ValueError
        If ``axis`` is not one of "x", "y", "z".
    """
    axis_map = {"x": 0, "y": 1, "z": 2}
    if axis not in axis_map:
        raise ValueError("axis must be 'x', 'y', or 'z'")

    vals = np.asarray(mesh.vertices)[:, axis_map[axis]]

    # Default covers both the empty mesh (min/max would raise on a zero-size
    # array) and the flat mesh (avoids division by zero).
    t = np.zeros_like(vals)
    if vals.size:
        min_val, max_val = vals.min(), vals.max()
        if max_val != min_val:
            t = (vals - min_val) / (max_val - min_val)

    c0 = np.array(color_start)
    c1 = np.array(color_end)
    colors = (1 - t)[:, None] * c0 + t[:, None] * c1

    mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
    return mesh
