from functools import partial
import open3d as o3d
import jax
import jax.numpy as jnp
import numpy as np
import mcubes
import einops
import os, sys
from tqdm import tqdm
from typing import Optional

BASEDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import util.latent_obj_util as loutil
import util.model_util as mutil
import util.camera_util as cutil
import util.transform_util as tutil

@partial(jax.jit, static_argnums=[1])
def FPS_padding(pnts, k, jkey):
    """Greedy farthest-point sampling of ``k`` points from ``pnts``.

    Points lie along axis -2 and coordinates along axis -1; any leading axes
    are treated as batch axes.

    ``pnts[..., 0, :]`` only seeds the search and is not part of the output.
    The first returned point is the one farthest from the seed. The seed's
    distances stay in the running minimum, so later picks also avoid it.

    Args:
        pnts: point array (points on axis -2, coordinates on axis -1).
        k: number of points to return; static under ``jax.jit``.
        jkey: unused; kept for interface compatibility.

    Returns:
        Array of shape ``pnts.shape[:-2] + (k, pnts.shape[-1])``. For ``k == 1``
        the centroid of ``pnts`` (with the point axis kept) is returned.
    """
    if k == 1:
        # A single "sample" is represented by the centroid.
        return jnp.mean(pnts, axis=-2, keepdims=True)

    def sq_dist_to(anchor):
        # Squared distance from every point to ``anchor`` (batch axes broadcast).
        return jnp.sum((pnts - anchor[..., None, :]) ** 2, axis=-1)

    def farthest(min_sq_dist):
        # Point whose distance to the current selection is largest.
        idx = jnp.argmax(min_sq_dist, axis=-1)
        return jnp.take_along_axis(pnts, idx[..., None, None], axis=-2).squeeze(-2)

    seed = pnts[..., 0, :]
    running_min = sq_dist_to(seed)
    # Pre-fill every slot with the first pick; slots 1..k-1 are overwritten below.
    picked = einops.repeat(farthest(running_min), '... i -> ... r i', r=k)

    def step(slot, state):
        chosen, min_sq = state
        min_sq = jnp.minimum(min_sq, sq_dist_to(chosen[..., slot, :]))
        chosen = chosen.at[..., slot + 1, :].set(farthest(min_sq))
        return chosen, min_sq

    picked, _ = jax.lax.fori_loop(0, k - 1, step, (picked, running_min))
    return picked


def create_scene_mesh_from_oriCORNs(latent_obj:loutil.LatentObjects, dec=None, 
                                    level=0.0, qp_bound=None, density=200, ndiv=200, center=None, visualize=True, vert_out=False):
    """Extract a triangle mesh from a decoder field using marching cubes.

    A ``density**3`` grid of query points is centered at ``center`` and spans
    ``[-qp_bounds, qp_bounds]`` per axis. The grid is evaluated with ``dec`` in
    chunks, and the maximum over the decoder's leading output axis is kept per
    point. The resulting volume is triangulated at iso-value ``-level``.

    Args:
        latent_obj: oriCORN objects to reconstruct, or ``None`` when ``dec``
            does not need them (e.g. a closure, as in
            ``create_swept_volume_from_oriCORNs``). When given, the valid
            objects are merged into a single object before decoding.
        dec: callable ``dec(latent_obj, query_points, jkey)`` with query points
            of shape ``(1, m, 3)``. If it returns a tuple/list, element 0 is
            used. Defaults to the jitted ``occ_prediction`` of the pretrained
            models.
        level: the marching-cubes threshold is ``-level``.
        qp_bound: half extent of the grid. Accepts a scalar (Python, NumPy or
            JAX) or a length-3 sequence/array. If ``None``, 1.2 times the
            AABB size of ``latent_obj`` is used (0.8 when ``latent_obj`` is
            ``None``).
        density: grid resolution per axis.
        ndiv: number of chunks the grid is split into for evaluation. It need
            not divide ``density**3``; the last chunk may be smaller.
        center: grid center. If ``None``, the AABB center of ``latent_obj`` is
            used (the origin when ``latent_obj`` is ``None``). An explicit value
            is always honored.
        visualize: open an Open3D viewer with the result.
        vert_out: return ``(verts, faces)`` from marching cubes instead of a mesh.

    Returns:
        ``open3d.geometry.TriangleMesh`` with vertex normals, or
        ``(verts, faces)`` when ``vert_out`` is true.

    Raises:
        ValueError: if ``ndiv < 1`` or ``qp_bound`` cannot broadcast to 3 values.
    """
    if latent_obj is not None:
        latent_obj = latent_obj.get_valid_oriCORNs()
        if latent_obj.ndim > 0:
            # Flatten and merge all objects so one decoder call covers the scene.
            latent_obj = latent_obj.reshape_outer_shape((-1,))
            latent_obj = latent_obj.merge()[None]
        AABB_target = latent_obj.AABB_fps
        if center is None:
            # Fall back to the AABB center only; an explicit ``center`` is kept.
            center = jnp.mean(jnp.stack(AABB_target, 0), axis=0).squeeze(0)
        if qp_bound is None:
            qp_bound = 1.2*(AABB_target[1] - AABB_target[0]).squeeze(0)
    elif qp_bound is None:
        qp_bound = 0.8
    if center is None:
        center = np.zeros(3)

    # Normalize to float64 NumPy: scalars broadcast to all three axes.
    qp_bounds = np.broadcast_to(np.asarray(qp_bound, dtype=np.float64), (3,)).copy()
    center = np.asarray(center, dtype=np.float64)

    if dec is None:
        models = mutil.Models().load_pretrained_models()
        dec = jax.jit(models.occ_prediction)
    if ndiv < 1:
        raise ValueError(f"ndiv must be >= 1, got {ndiv}")

    # Query grid ('xy' meshgrid indexing; undone by the transpose further below).
    axes = [np.linspace(-qp_bounds[d], qp_bounds[d], density, endpoint=True) + center[..., d]
            for d in range(3)]
    xv, yv, zv = np.meshgrid(*axes)
    grid = np.stack([xv, yv, zv]).astype(np.float32).reshape(3, -1).transpose()[None]
    grid = jnp.array(grid)

    # Evaluate in chunks to bound memory. Results are collected in a list and
    # concatenated once; re-concatenating a growing array each step is quadratic.
    npts = grid.shape[1]
    chunk = -(-npts // ndiv)  # ceil division: the last chunk may be smaller
    jkey = jax.random.PRNGKey(21)
    values = []
    for start in tqdm(range(0, npts, chunk)):
        _, jkey = jax.random.split(jkey)
        out = dec(latent_obj, grid[:, start:start + chunk], jkey)
        if isinstance(out, (tuple, list)):
            out = out[0]
        # Max over the decoder's leading output axis (presumably objects).
        values.append(np.asarray(np.max(out, axis=0)))
    volume = np.concatenate(values, 0).reshape(density, density, density).transpose(1, 0, 2)

    verts, faces = mcubes.marching_cubes(volume, -level)
    # Grid index -> world coordinates (inverse of the linspace above).
    verts = verts * (2 * qp_bounds / (density - 1)) - qp_bounds + center

    if vert_out:
        return verts, faces

    mesh = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(verts), triangles=o3d.utility.Vector3iVector(faces))
    mesh.compute_vertex_normals()

    if visualize:
        mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.2, origin=np.array([0., 0., 0.]))
        geometries = [mesh, mesh_frame]
        if latent_obj is not None:
            fps = np.asarray(latent_obj.fps_tf, dtype=np.float64).reshape(-1, 3)
            obj_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(fps)).paint_uniform_color([1,0,0])
            geometries.insert(0, obj_pcd)
        o3d.visualization.draw_geometries(geometries)

    return mesh


def create_fps_fcd_from_oriCORNs(latent_obj:loutil.LatentObjects, sphere_radius=None, color=None, visualize=True):
    """Build Open3D geometry marking the FPS points of ``latent_obj``.

    Args:
        latent_obj: oriCORN objects; all points are flattened into one list.
        sphere_radius: if given, one sphere per ``fps_tf`` point is created
            with radius ``min(latent_obj.mean_fps_dist, sphere_radius)``.
            Otherwise a single point cloud is built from ``fps_tf``, or from
            ``pcd_tf`` when ``latent_obj.rel_fps`` is ``None``.
        color: RGB color in [0, 1] for the spheres or points; defaults to red.
        visualize: open an Open3D viewer with the geometry.

    Returns:
        A list of sphere meshes when ``sphere_radius`` is given, otherwise the
        point cloud. The sphere path always returns the spheres. Previously,
        with ``visualize=True``, it fell through, opened a second window and
        returned the point cloud.
    """
    if color is None:
        color = [1, 0, 0]

    if sphere_radius is not None:
        # np.min collapses a non-scalar mean_fps_dist (e.g. one value per
        # object -- not verified here) to the float that create_sphere needs.
        radius = float(np.min(np.minimum(latent_obj.mean_fps_dist, sphere_radius)))
        fps_tf = np.asarray(latent_obj.fps_tf, dtype=np.float64).reshape(-1, 3)
        mesh_list = []
        for point in fps_tf:
            sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius)
            sphere.compute_vertex_normals()
            sphere.translate(point)
            sphere.paint_uniform_color(color)
            mesh_list.append(sphere)
        if visualize:
            o3d.visualization.draw_geometries(mesh_list)
        return mesh_list

    pnts = latent_obj.fps_tf if latent_obj.rel_fps is not None else latent_obj.pcd_tf
    pnts = np.asarray(pnts, dtype=np.float64).reshape(-1, 3)
    obj_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pnts)).paint_uniform_color(color)
    if visualize:
        mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.2, origin=np.array([0., 0., 0.]))
        o3d.visualization.draw_geometries([obj_pcd, mesh_frame])
    return obj_pcd


def create_swept_volume_from_oriCORNs(latent_obj, line_segment, times, level=0.0, qp_bound=0.8, density=200, ndiv=200, center=None, visualize=True, models=None):
    """Mesh a swept volume using the collision decoder and marching cubes.

    The FPS points of ``latent_obj`` are shifted by ``-line_segment * times``.
    Each shifted FPS point, with its latent ``z``, becomes a separate
    single-point object (the added ``[..., None, :]`` axes). The field value at
    a query point is the maximum ``col_decoder`` output between a zero-latent
    point oriCORN placed there and those objects, called with
    ``line_segment_B=line_segment`` (presumably the sweep vector -- not verified
    here). The field is meshed by ``create_scene_mesh_from_oriCORNs``.

    Args:
        latent_obj: oriCORN objects to sweep.
        line_segment: segment vectors; after ``reshape(-1, 3)`` there must be
            one per flattened primitive.
        times: scales applied as ``line_segment * times[..., None]``.
        level, qp_bound, density, ndiv, visualize: forwarded to
            ``create_scene_mesh_from_oriCORNs``.
        center: grid center; defaults to the mean of the shifted FPS points.
        models: pretrained ``mutil.Models``. Loaded from disk when ``None``;
            pass it in to avoid reloading on every call.

    Returns:
        The ``open3d.geometry.TriangleMesh`` returned by
        ``create_scene_mesh_from_oriCORNs``.

    Raises:
        ValueError: if the primitive count does not match the segment count.
    """
    translation_offset = -line_segment*times[...,None]
    latent_obj = latent_obj.replace(rel_fps=(latent_obj.fps_tf+translation_offset)[...,None,:], z= latent_obj.z[...,None,:,:]).init_pos_zero()
    latent_obj = latent_obj.reshape_outer_shape((-1,))
    line_segment = line_segment.reshape(-1,3)
    if latent_obj.shape[0] != line_segment.shape[0]:
        raise ValueError(
            f"number of swept primitives ({latent_obj.shape[0]}) does not match "
            f"number of line segments ({line_segment.shape[0]})")

    if models is None:
        models = mutil.Models().load_pretrained_models()

    def dec_func(_unused_latent, qpoints, jkey):
        # Follows the dec(latent_obj, query_points, jkey) contract of
        # create_scene_mesh_from_oriCORNs; the swept objects come from the closure.
        original_outer_shape = qpoints.shape[:-2]
        qpoints = qpoints.reshape(-1, 3)
        nq = qpoints.shape[0]

        # One zero-latent, single-point oriCORN per query point.
        point_oriCORN = loutil.LatentObjects(rel_fps=jnp.zeros((1,3)).astype(jnp.float32), z=jnp.zeros((1,)+models.latent_shape[1:]).astype(jnp.float32)).init_pos_zero()
        point_oriCORN = point_oriCORN.extend_and_repeat_outer_shape(nq, 0)
        point_oriCORN = point_oriCORN.translate(qpoints)

        collision_loss_pair = jax.vmap(partial(models.apply, 'col_decoder', latent_obj_B=latent_obj, line_segment_B=line_segment,
                                               reduce_k=4, pairwise_out=True, jkey=jkey))(point_oriCORN)

        res = jnp.max(collision_loss_pair, axis=(-1,-2,-3,-4))
        return res.reshape(original_outer_shape + res.shape[-1:])

    if center is None:
        center = jnp.mean(latent_obj.fps_tf.reshape(-1,3), axis=0)
    return create_scene_mesh_from_oriCORNs(None, dec=jax.jit(dec_func), visualize=visualize, center=center,
                                           level=level, density=density, ndiv=ndiv, qp_bound=qp_bound)


def sample_surface_pnts(jkey, oriCORN:loutil.LatentObjects, models:mutil.Models, nsp=1000):
    """Sample ``nsp`` points per object where the occupancy prediction is high.

    Candidates are drawn from isotropic Gaussians around every FPS point
    (std = ``mean_fps_dist / 3``) and shuffled along the point axis. Those
    with ``models.occ_prediction(...) > -2.0`` are kept; the threshold's
    meaning, presumably "inside or near the surface", is not verified here.
    The first ``nsp`` accepted candidates are returned. If fewer pass, the
    remainder is padded with the first accepted candidate. Earlier versions
    padded with candidate 0, which may not pass the threshold. If no candidate
    passes, index 0 is used.

    Args:
        jkey: JAX PRNG key; split into independent subkeys for noise,
            shuffling and the decoder.
        oriCORN: objects to sample; ``oriCORN.nfps`` must be a static int.
        models: provides ``occ_prediction(oriCORN, points, jkey)``.
        nsp: number of points returned per object.

    Returns:
        Array of shape ``outer_shape + (nsp, 3)``.
    """
    # Independent subkeys: reusing one key for several draws correlates them.
    noise_key, perm_key, occ_key = jax.random.split(jkey, 3)

    # About 3*nsp candidates in total, spread over the FPS points (at least 1 each).
    nsampler_per_fps = max(1, (nsp*3) // oriCORN.nfps)

    mean_fps_dist = oriCORN.mean_fps_dist
    fps_tf = oriCORN.fps_tf

    noise = jax.random.normal(noise_key, fps_tf.shape[:-1] + (nsampler_per_fps, 3))
    surface_pnts_candidates = fps_tf[...,None,:] + noise*(mean_fps_dist[...,None,None,None]/3.0)
    surface_pnts_candidates = einops.rearrange(surface_pnts_candidates, '... i j k -> ... (i j) k')
    # Shuffle so the first-nsp selection below is not biased toward the first FPS points.
    surface_pnts_candidates = jax.random.permutation(perm_key, surface_pnts_candidates, axis=-2)

    occ_res = models.occ_prediction(oriCORN, surface_pnts_candidates, occ_key)
    surface_mask = (occ_res > -2.0).squeeze(-1)

    origin_outer_shape = surface_mask.shape[:-1]
    ncandidates = surface_mask.shape[-1]

    def _pick(mask):
        # Indices of the first nsp accepted candidates (static size for jit).
        idx = jnp.where(mask, size=nsp, fill_value=0)[0]
        first_valid = jnp.argmax(mask)  # first True index, or 0 if none
        return jnp.where(jnp.arange(nsp) < jnp.sum(mask), idx, first_valid)

    surface_pnts_idx = jax.vmap(_pick)(surface_mask.reshape(-1, ncandidates)).astype(jnp.int32)
    surface_pnts = jnp.take_along_axis(surface_pnts_candidates.reshape(-1, ncandidates, 3), surface_pnts_idx[...,None], axis=-2)
    return surface_pnts.reshape(origin_outer_shape + (nsp, 3))


def create_primitives(oriCORNs:loutil.LatentObjects, dec, level=0.0, visualize=False):
    """Mesh each FPS point of a single oriCORN object separately and merge the results.

    Each primitive is one FPS point with its latent. It is reconstructed with
    ``create_scene_mesh_from_oriCORNs`` using a shared half extent of 0.4 times
    the object's largest AABB side, and painted a random color from a
    generator seeded with 21. All primitives are merged into one mesh.

    Args:
        oriCORNs: must contain exactly one object after flattening.
        dec: decoder passed to ``create_scene_mesh_from_oriCORNs``.
        level: iso-level passed to ``create_scene_mesh_from_oriCORNs``.
        visualize: open an Open3D viewer with the merged mesh.

    Returns:
        The merged ``open3d.geometry.TriangleMesh``, which is empty if there are
        no primitives.
    """
    if oriCORNs.ndim > 0:
        oriCORNs = oriCORNs.reshape_outer_shape((-1,))
        oriCORNs = oriCORNs.squeeze_outer_shape(axis=0)

    assert oriCORNs.ndim == 0

    AABB_target = oriCORNs.AABB_fps
    qp_bound = np.array(np.max(AABB_target[1] - AABB_target[0])*0.4)

    o3d_mesh_list = []
    np_rng = np.random.default_rng(21)
    for i in range(oriCORNs.nfps):
        oriCORNs_primitives:loutil.LatentObjects = oriCORNs.replace(rel_fps=oriCORNs.rel_fps[...,i:i+1,:], z=oriCORNs.z[...,i:i+1,:,:])
        mesh_o3d = create_scene_mesh_from_oriCORNs(oriCORNs_primitives, dec=dec, qp_bound=qp_bound, level=level, density=200, ndiv=40, visualize=False)
        mesh_o3d.paint_uniform_color(np_rng.random(size=(3,)))
        o3d_mesh_list.append(mesh_o3d)

    # Concatenate the meshes, shifting triangle indices by the running vertex count.
    vertex_blocks, color_blocks, triangle_blocks = [], [], []
    offset = 0
    for mesh in o3d_mesh_list:
        vertices = np.asarray(mesh.vertices)
        vertex_blocks.append(vertices)
        color_blocks.append(np.asarray(mesh.vertex_colors).reshape(-1, 3))
        triangle_blocks.append(np.asarray(mesh.triangles) + offset)
        offset += vertices.shape[0]

    if vertex_blocks:
        all_vertices = np.vstack(vertex_blocks)
        merged_mesh = o3d.geometry.TriangleMesh(
            vertices=o3d.utility.Vector3dVector(all_vertices),
            triangles=o3d.utility.Vector3iVector(np.vstack(triangle_blocks).astype(np.int32))
        )
        all_colors = np.vstack(color_blocks)
        # Attach colors only when every vertex has one. Checking only the first
        # mesh dropped all colors whenever that primitive happened to be empty.
        if all_colors.shape[0] > 0 and all_colors.shape[0] == all_vertices.shape[0]:
            merged_mesh.vertex_colors = o3d.utility.Vector3dVector(all_colors)
    else:
        merged_mesh = o3d.geometry.TriangleMesh()
    merged_mesh.compute_vertex_normals()

    if visualize:
        o3d.visualization.draw_geometries([merged_mesh])

    return merged_mesh

def _stream_camera_frame(cam_pqc):
    """Return a size-0.1 coordinate frame transformed by ``tutil.pq2H(cam_pqc[:3], cam_pqc[3:])``."""
    Tm = tutil.pq2H(cam_pqc[:3], cam_pqc[3:])
    frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.1, origin=[0, 0, 0])
    frame.transform(np.array(Tm))
    return frame


def stream_oriCORNs(oriCORNs:loutil.LatentObjects, models, colors=None, depths=None, cam_intrinsics=None, cam_pqcs=None, carry=None):
    """Incrementally visualize oriCORN surface samples in a persistent Open3D window.

    The first call (``carry is None``) creates the window and its geometries.
    Pass the returned tuple back as ``carry`` to update them in place on later
    calls. Geometries missing from the carry are added on demand: extra
    objects, extra cameras, or an environment cloud that was not created at
    init. The environment point cloud is refreshed only when both ``colors``
    and ``depths`` are given; otherwise its last contents are kept.

    Args:
        oriCORNs: objects to show; a 2-D outer shape is flattened.
        models: used by ``sample_surface_pnts`` (400 points per object).
        colors, depths: image batches; only element 0 is used for the
            environment cloud via ``cutil.np2o3d_img2pcd``.
        cam_intrinsics, cam_pqcs: per-camera intrinsics and poses; one
            coordinate frame is drawn per camera.
        carry: state returned by a previous call, or ``None``.

    Returns:
        ``(vis_o3d, oriCORNs_pcds, env_pcd, cams, surface_pnts_func)`` to pass
        back as ``carry``.
    """
    if oriCORNs.ndim==2:
        oriCORNs = oriCORNs.reshape_outer_shape((-1,))

    if carry is None:
        surface_pnts_func = jax.jit(partial(sample_surface_pnts, models=models, nsp=400))

        vis_o3d = o3d.visualization.Visualizer()
        vis_o3d.create_window()
        vis_o3d.add_geometry(o3d.geometry.TriangleMesh.create_coordinate_frame(0.2))

        oriCORNs_pcds = []
        for _ in range(oriCORNs.nobj):
            pcd = o3d.geometry.PointCloud()
            vis_o3d.add_geometry(pcd)
            oriCORNs_pcds.append(pcd)

        cams = []
        if cam_intrinsics is not None:
            for i in range(cam_intrinsics.shape[0]):
                frame = _stream_camera_frame(cam_pqcs[i])
                vis_o3d.add_geometry(frame)
                cams.append(frame)

        env_pcd = None
        if colors is not None:
            env_pcd = o3d.geometry.PointCloud()
            vis_o3d.add_geometry(env_pcd)
    else:
        vis_o3d, oriCORNs_pcds, env_pcd, cams, surface_pnts_func = carry

    surface_pnts = np.asarray(surface_pnts_func(jax.random.PRNGKey(11), oriCORNs), dtype=np.float64)
    colors_map = np.random.default_rng(21).random(size=(100, 3))
    for i in range(oriCORNs.nobj):
        if i >= len(oriCORNs_pcds):
            # More objects than at init: add without resetting the camera view.
            new_pcd = o3d.geometry.PointCloud()
            vis_o3d.add_geometry(new_pcd, reset_bounding_box=False)
            oriCORNs_pcds.append(new_pcd)
        pnts = surface_pnts[i].reshape(-1, 3)
        oriCORNs_pcds[i].points = o3d.utility.Vector3dVector(pnts)
        # Cycle the palette so more than 100 objects do not overflow it.
        obj_color = colors_map[i % len(colors_map)]
        oriCORNs_pcds[i].colors = o3d.utility.Vector3dVector(np.repeat(obj_color[None], pnts.shape[0], axis=0))
        vis_o3d.update_geometry(oriCORNs_pcds[i])

    if colors is not None and depths is not None:
        if env_pcd is None:
            env_pcd = o3d.geometry.PointCloud()
            vis_o3d.add_geometry(env_pcd, reset_bounding_box=False)
        env_pcd_new = cutil.np2o3d_img2pcd(colors[0], depths[0], np.array(cam_intrinsics[0]), np.array(cam_pqcs[0]))
        env_pcd.points = env_pcd_new.points
        env_pcd.colors = env_pcd_new.colors
        vis_o3d.update_geometry(env_pcd)
    # No else branch: keep the env_pcd handle so later frames can still update it.

    if cam_intrinsics is not None:
        for i in range(cam_intrinsics.shape[0]):
            frame = _stream_camera_frame(cam_pqcs[i])
            if i < len(cams):
                cams[i].vertices = frame.vertices
                vis_o3d.update_geometry(cams[i])
            else:
                vis_o3d.add_geometry(frame, reset_bounding_box=False)
                cams.append(frame)

    for _ in range(4):
        vis_o3d.poll_events()
        vis_o3d.update_renderer()

    return vis_o3d, oriCORNs_pcds, env_pcd, cams, surface_pnts_func

if __name__ == '__main__':
    import pickle
    # Checkpoint path is relative to the current working directory.
    load_estimation_model_dirname = 'dif_ckpt/dim160_fm_reg/saved.pkl'
    # Alternative checkpoint:
    # load_estimation_model_dirname = 'dif_ckpt/dim160_fm_reg_cam/saved.pkl'
    # NOTE: pickle can execute arbitrary code on load; only open trusted checkpoints.
    with open(load_estimation_model_dirname, 'rb') as f:
        estimation_model = pickle.load(f)

    models:mutil.Models = estimation_model['models']
    models = models.set_params(estimation_model['ema_params'])
    del estimation_model
    models = models.load_dino_params()

    jkey = jax.random.PRNGKey(1)
    target_oriCORN = models.canonical_latent_obj[0]
    surface_pnts = sample_surface_pnts(jkey, target_oriCORN, models, nsp=1000)

    # Visualize the surface points. Open3D expects a float64 (N, 3) NumPy array.
    pnts = np.asarray(surface_pnts, dtype=np.float64).reshape(-1, 3)
    obj_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pnts))
    mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.2, origin=np.array([0., 0., 0.]))
    o3d.visualization.draw_geometries([obj_pcd, mesh_frame])