import jax.numpy as jnp
import open3d as o3d
import numpy as np
import jax
import optax
from tqdm import tqdm
import pickle
import open3d as o3d
from typing import Optional

import os, sys
BASEDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import util.model_util as mutil
import util.latent_obj_util as loutil
import util.reconstruction_util as rcutil
import dataset.generate_col_data_from_mesh as gencol
import util.structs as structs
import util.transform_util as tutil

def path_col_datagen(col_datapoints:structs.ColDataset, oriCORN_A:loutil.LatentObjects, oriCORN_B:loutil.LatentObjects, jkey, debug=False):
    """Translate two latent objects along random straight-line paths.

    For every sample an independent unit direction is drawn for object A and
    object B.  Where ``col_gt`` is true the raw random direction is used;
    otherwise the direction is the cross product of ``min_direction`` with the
    random vector, i.e. perpendicular to ``min_direction``.  ``npath`` (=5)
    signed distances are then drawn along each direction and shifted so that
    one randomly chosen waypoint has zero offset (the unperturbed pose).

    Args:
        col_datapoints: batch providing ``col_gt`` and ``min_direction``
            (``min_direction`` is squeezed on axis -2 before use).
        oriCORN_A: first latent object batch.
        oriCORN_B: second latent object batch.
        jkey: JAX PRNG key; consumed through successive splits.
        debug: when true, show the translated ``fps_tf`` points of both
            objects in an Open3D window, one window per batch element.

    Returns:
        ``(oriCORN_A, oriCORN_B)`` after ``extend_and_repeat_outer_shape(npath, -1)``
        and translation by the per-waypoint offsets.
        NOTE(review): presumably this adds a trailing outer axis of size
        ``npath`` -- confirm against ``LatentObjects``.
    """
    npath = 5
    col_gt = col_datapoints.col_gt
    min_direction = col_datapoints.min_direction.squeeze(-2)
    outer_shape = min_direction.shape[:-1] # (NB, )

    jkey, subkey = jax.random.split(jkey)
    path_direction = jax.random.normal(subkey, shape=outer_shape + (2, 3))
    path_direction_perp = jnp.cross(min_direction[...,None, :], path_direction)
    # colliding pairs move in any direction, non-colliding pairs move perpendicular to min_direction
    path_direction = jnp.where(col_gt[...,None], path_direction, path_direction_perp)
    path_direction = path_direction/jnp.linalg.norm(path_direction, axis=-1, keepdims=True) # (NB, 2, 3)

    jkey, subkey = jax.random.split(jkey)
    path_len = jax.random.uniform(subkey, shape=outer_shape+(2, 1,), minval=0.05, maxval=0.5)
    jkey, subkey = jax.random.split(jkey)
    path_distance = jax.random.uniform(subkey, shape=outer_shape+(2, npath,)) * path_len # (NB, 2, npath)
    jkey, subkey = jax.random.split(jkey)
    random_idx = jax.random.randint(subkey, shape=outer_shape+(2, 1,), minval=0, maxval=path_distance.shape[-1])
    # subtract one randomly chosen waypoint so that waypoint sits exactly at zero offset
    path_distance = path_distance - jnp.take_along_axis(path_distance, random_idx, axis=-1) # (NB, 2, npath)
    path_ptb = path_direction[...,None,:] * path_distance[...,None] # (NB, 2, npath, 3)

    oriCORN_A = oriCORN_A.extend_and_repeat_outer_shape(npath, -1)
    oriCORN_A = oriCORN_A.translate(path_ptb[:,0])
    oriCORN_B = oriCORN_B.extend_and_repeat_outer_shape(npath, -1)
    oriCORN_B = oriCORN_B.translate(path_ptb[:,1])

    if debug:
        # visualize path: every waypoint of both objects, one window per batch element
        import open3d as o3d
        for i in range(outer_shape[0]):
            coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.2, origin=np.array([0., 0., 0.]))
            geometries = [coordinate_frame]
            for obj in (oriCORN_A, oriCORN_B):
                for k in range(npath):
                    path_pcd = o3d.geometry.PointCloud()
                    path_pcd.points = o3d.utility.Vector3dVector(np.array(obj[i,k].fps_tf))
                    # collect every cloud; previously only the last one created was drawn
                    geometries.append(path_pcd)
            o3d.visualization.draw_geometries(geometries, point_show_normal=True)

    return oriCORN_A, oriCORN_B


def inside_check(scene, pnts, cnt_check=True):
    """Classify points as inside/outside the geometry held by a raycasting scene.

    Args:
        scene: object exposing ``count_intersections(rays)`` (the callers in
            this file pass an ``o3d.t.geometry.RaycastingScene``).
        pnts: array of 3D points, last axis of size 3.
        cnt_check: if true, cast one random ray per point and call the point
            inside when the number of intersections is odd.  If false, cast
            100 random rays per point and call the point inside when at least
            99 of them hit something.

    Returns:
        Boolean array; for the multi-ray variant it is flattened to one entry
        per point.
    """
    def _count_hits(origins, directions):
        # Open3D expects rays as rows of [ox, oy, oz, dx, dy, dz] in float32.
        packed = np.concatenate([origins, directions], axis=-1)
        ray_tensor = o3d.core.Tensor(packed, dtype=o3d.core.Dtype.Float32)
        return scene.count_intersections(ray_tensor).numpy()

    if not cnt_check:
        ray_sample = 100
        directions = np.random.normal(size=pnts.shape[:-1] + (ray_sample,) + pnts.shape[-1:])
        directions = directions/np.linalg.norm(directions, axis=-1, keepdims=True)
        origins = pnts[...,None,:].repeat(ray_sample, -2)
        hits = _count_hits(origins.reshape(-1,3), directions.reshape(-1,3))
        hits = hits.reshape(-1, ray_sample)
        # tolerate a single escaping ray out of ray_sample
        return np.sum(hits>=1, -1) >= ray_sample-1

    directions = np.random.normal(size=pnts.shape)
    directions = directions/np.linalg.norm(directions, axis=-1, keepdims=True)
    hits = _count_hits(pnts, directions)
    # odd number of crossings -> the point is enclosed
    return (hits%2==1)


# Base noise scale (in normalised mesh units) used by create_obj_dataset when
# jittering surface samples into near-surface query points.
TRUNCATED_VAL = 0.03
def _balanced_order(mask, keep):
    """Return up to ``keep`` indices that interleave True and False entries of ``mask``.

    Every entry is weighted by its position times the size of its own class,
    so entries of the smaller class sort earlier and the leading slice of the
    ordering is roughly class-balanced.
    """
    # int64 keeps position * count from overflowing where the default integer is 32-bit
    positions = np.arange(mask.shape[-1], dtype=np.int64)
    n_true = np.int64(np.sum(mask))
    n_false = np.int64(np.sum(1-mask))
    weight = np.where(mask, positions * n_true, positions * n_false)
    return np.argsort(weight, -1)[...,:keep]


def create_obj_dataset(mesh_file_name, nquery=80000, visualize=False, robust_sampling=True, surface_pnt_robust_sampling=True):
    """Build signed-distance and ray-hit training samples from a mesh file.

    The mesh is recentred on its bounding-box centre and scaled so that its
    longest bounding-box edge becomes 1.8.  All returned samples live in this
    normalised frame; ``translation`` and ``scale`` describe the transform.

    Args:
        mesh_file_name: path of a mesh readable by ``o3d.io.read_triangle_mesh``.
        nquery: maximum number of query points (and rays) kept after balancing.
        visualize: open Open3D windows showing intermediate results.
        robust_sampling: replace the sign of the computed signed distance by
            the multi-ray ``inside_check`` result.
        surface_pnt_robust_sampling: drop surface samples whose position,
            pushed 0.001 along the normal, is still classified as inside.

    Returns:
        dict with keys ``path``, ``file_name``, ``translation``, ``scale``,
        ``query_points`` (N,3), ``signed_distance`` (N,1), ``normal_vector``
        (N,3), ``surface_points``, ``ray_points`` (ray origins),
        ``ray_directions`` (unit directions multiplied by 10) and
        ``ray_hitting_gt`` (hit distance, -1 for a miss).

    Raises:
        ValueError: if the mesh has no vertices.
        AssertionError: if the mesh is degenerate or too few inside samples
            were found.
    """
    mesh_legacy = o3d.io.read_triangle_mesh(mesh_file_name)
    mesh_legacy.compute_vertex_normals()

    vertices = np.asarray(mesh_legacy.vertices)
    if vertices.shape[0] == 0:
        # an unreadable file can yield an empty mesh; report it with the offending path
        raise ValueError(f"mesh has no vertices: {mesh_file_name}")
    min_bound = vertices.min(0)
    max_bound = vertices.max(0)
    cen_bound = (min_bound + max_bound)/2
    max_len = np.max(max_bound - min_bound)
    assert max_len > 1e-7, f"max_len is too small: {max_len} / {mesh_file_name}"
    scale = (2.0/max_len)*0.9
    mesh_legacy = mesh_legacy.translate(-cen_bound)
    mesh_legacy = mesh_legacy.scale(scale, np.zeros(3))

    meta_data = {
        'path':mesh_file_name,
        'file_name':os.path.basename(mesh_file_name),
        'translation':-cen_bound,
        'scale':scale,
    }

    if visualize:
        mesh_legacy.compute_vertex_normals()
        o3d.visualization.draw_geometries([mesh_legacy])

    # over-generate so that balancing / filtering can discard samples
    enlarged_dssize = int(nquery*2.0)
    surface_npnt = int(enlarged_dssize*0.8)
    oversample = robust_sampling or surface_pnt_robust_sampling
    surface_pcd = mesh_legacy.sample_points_uniformly(number_of_points=int(4*surface_npnt) if oversample else surface_npnt)
    surface_pnts = np.array(surface_pcd.points).astype(np.float32)
    surface_nmls = np.array(surface_pcd.normals).astype(np.float32)
    permutation_idx = np.random.permutation(np.arange(surface_pnts.shape[0]))
    surface_pnts = surface_pnts[permutation_idx]
    surface_nmls = surface_nmls[permutation_idx]

    mesh = o3d.t.geometry.TriangleMesh.from_legacy(mesh_legacy)

    # Create a scene and add the triangle mesh
    scene = o3d.t.geometry.RaycastingScene()
    _ = scene.add_triangles(mesh)  # we do not need the geometry ID for mesh

    if surface_pnt_robust_sampling:
        # a surface sample nudged along its normal should be outside; drop the ones that are not
        res = inside_check(scene, surface_pnts + surface_nmls*0.001, cnt_check=False)
        outside = np.logical_not(res)
        surface_pnts = surface_pnts[outside][:surface_npnt]
        surface_nmls = surface_nmls[outside][:surface_npnt]

        if visualize:
            surface_pcd.points = o3d.utility.Vector3dVector(surface_pnts)
            surface_pcd.normals = o3d.utility.Vector3dVector(surface_nmls)
            mesh_legacy.compute_vertex_normals()
            o3d.visualization.draw_geometries([surface_pcd, mesh_legacy], point_show_normal=True)

    # 20% uniform samples in a box slightly larger than the normalised mesh,
    # the rest are surface samples with small / large Gaussian jitter
    query_points_uniform = np.random.uniform(np.array([-1.2,-1.2,-1.2]), np.array([1.2,1.2,1.2]), size=(int(enlarged_dssize*0.2),3)).astype(np.float32)
    close_npnt = int(surface_pnts.shape[0]*0.625)
    far_npnt = surface_pnts.shape[0] - close_npnt
    query_points_surface1 = surface_pnts[:close_npnt] + np.random.normal(size=(close_npnt,3))*TRUNCATED_VAL*0.4
    query_points_surface2 = surface_pnts[close_npnt:] + np.random.normal(size=(far_npnt,3))*TRUNCATED_VAL*2.0

    query_points = np.concatenate([query_points_surface1, query_points_surface2, query_points_uniform],0).astype(np.float32)
    query_points = np.random.permutation(query_points)

    signed_distance = scene.compute_signed_distance(query_points)
    signed_distance = signed_distance.numpy()

    if robust_sampling:
        # keep the distance magnitude but take the sign from the multi-ray vote
        inside = inside_check(scene, query_points, cnt_check=False)
        signed_distance = np.abs(signed_distance)
        signed_distance[inside] *= -1

    # balancing with signed_distance
    idx = _balanced_order(signed_distance < 0, nquery)
    query_points = query_points[idx]
    signed_distance = signed_distance[idx]

    # filter
    print(f"inside npnts: {np.sum(signed_distance<0)} / {mesh_file_name}")

    assert np.sum(signed_distance<0) > nquery*0.15, f"inside npnts is too small: {np.sum(signed_distance<0)} / {mesh_file_name}"

    ans = scene.compute_closest_points(query_points)
    closest_points = ans['points'].numpy()
    # multiplying by the sign flips the vector for inside points
    normal_vector = np.sign(signed_distance)[...,None]*(query_points-closest_points)

    query_points = np.array(query_points.reshape(-1,3))
    signed_distance = np.array(signed_distance.reshape(-1,1))
    normal_vector = np.array(normal_vector.reshape(-1,3))
    normal_vector = normal_vector/(np.linalg.norm(normal_vector, axis=-1, keepdims=True)+1e-6)

    if visualize:
        # show inside samples first, then outside samples
        flat_sd = signed_distance.squeeze(-1)
        for selector in (flat_sd<0, flat_sd>=0):
            root_pcd = o3d.geometry.PointCloud()
            root_pcd.points = o3d.utility.Vector3dVector(np.array(query_points[selector]))
            root_pcd.normals = o3d.utility.Vector3dVector(np.array(normal_vector[selector]))
            coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.2, origin=np.array([0., 0., 0.]))
            o3d.visualization.draw_geometries([root_pcd, coordinate_frame], point_show_normal=True)

    # ray queries: two random rays per query point, each starting 2 units
    # behind the point so that the ray passes through it
    ray_origin = np.concatenate([query_points, query_points], axis=-2)
    ray_direction = np.random.normal(size=ray_origin.shape)
    ray_direction = ray_direction/np.linalg.norm(ray_direction, axis=-1,keepdims=True)
    ray_origin = ray_origin - 2.0*ray_direction

    rays_np = np.concatenate([ray_origin, ray_direction], axis=-1)
    rays = o3d.core.Tensor(rays_np, dtype=o3d.core.Dtype.Float32)
    ans = scene.cast_rays(rays)
    t_hit = ans['t_hit'].numpy()

    # balancing between rays that miss (t_hit == inf) and rays that hit
    idx = _balanced_order(np.isinf(t_hit), query_points.shape[-2])
    rays_np = rays_np[idx]
    t_hit = t_hit[idx]

    ray_points = rays_np[...,:3]

    # misses are encoded as -1 so consumers can test `t_hit > 0`
    t_hit = np.where(np.isinf(t_hit), -1, t_hit)

    meta_data['query_points'] = query_points
    meta_data['signed_distance'] = signed_distance
    meta_data['normal_vector'] = normal_vector
    meta_data['surface_points'] = surface_pnts

    meta_data['ray_points'] = ray_points
    meta_data['ray_directions'] = 10*rays_np[...,3:]
    meta_data['ray_hitting_gt'] = t_hit

    if visualize:
        # inside query points together with up to 20 rays that miss the mesh
        inside_query_pcd = o3d.geometry.PointCloud()
        inside_query_pcd.points = o3d.utility.Vector3dVector(np.array(meta_data['query_points'][meta_data['signed_distance'].reshape(-1)<0]))

        pick_idx = np.where(meta_data['ray_hitting_gt']<0)[0][:20]
        ray_start_pnt = meta_data['ray_points'][pick_idx]
        ray_end_pnt = ray_start_pnt + meta_data['ray_directions'][pick_idx]
        ray_pnts = np.concatenate([ray_start_pnt, ray_end_pnt], axis=-2)
        half = ray_pnts.shape[-2]//2
        ray_lineset = o3d.geometry.LineSet()
        ray_lineset.points = o3d.utility.Vector3dVector(ray_pnts.reshape(-1,3))
        ray_lineset.lines = o3d.utility.Vector2iVector(np.stack([np.arange(half), np.arange(half)+half], axis=-1))
        coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.2, origin=np.array([0., 0., 0.]))
        o3d.visualization.draw_geometries([inside_query_pcd, coordinate_frame, ray_lineset], point_show_normal=True)

    return meta_data


def encode_mesh(models:mutil.Models, mesh_filename:str=None, train_data=None, nquery:int=500000, 
                nfps_multiplier:int=1, batch_size:int=1024, niter:int=1000, use_col_loss:bool=True, visualize=False):
    """Fit a latent object to a mesh by gradient descent on its latent features ``z``.

    Only ``z`` is optimised (AdamW, learning rate 3e-4); the decoders in
    ``models`` are used as they are.  The loss sums occupancy BCE, optional
    ray-hit BCE and optional collision BCE over ``reduce_k`` in (8, 16).
    Each outer iteration performs four optimiser updates.

    Args:
        models: pretrained model bundle.
        mesh_filename: mesh to encode; also determines the output path and is
            required for the collision loss.
        train_data: precomputed dataset dict (same keys as returned by
            ``create_obj_dataset``).  When ``None`` it is generated from
            ``mesh_filename`` and the result is mapped back to the original
            mesh frame before saving.
        nquery: number of query points requested when the dataset is generated.
        nfps_multiplier: multiplies the number of FPS points / latent rows.
        batch_size: query points (and rays) per optimiser update.
        niter: number of outer iterations.
        use_col_loss: add the collision loss (needs ``mesh_filename``).
        visualize: show debug windows; disables ``jax.jit`` for the train step.

    Returns:
        The fitted latent object, which is also pickled to
        ``assets_oriCORNs/<ckpt_id>/<dataset>/<mesh>.pkl`` (or
        ``assets_oriCORNs/tmp.pkl`` without a mesh file name).

    Raises:
        ValueError: if neither ``mesh_filename`` nor ``train_data`` is given.
    """
    if train_data is None and mesh_filename is None:
        raise ValueError("encode_mesh needs either mesh_filename or train_data")

    mesh_normalized = False
    if train_data is None:
        mesh_normalized = True
        train_data = create_obj_dataset(mesh_filename, nquery=nquery, visualize=visualize, robust_sampling=True, surface_pnt_robust_sampling=True)

    col_data_total = None
    oriCORNs_col_data = None
    if mesh_filename is not None and use_col_loss:
        dataset_name = os.path.basename(os.path.dirname(os.path.dirname(mesh_filename)))
        pretrain_model_id = models.pretrain_ckpt_id

        output_col_filename = f'assets_oriCORNs/{pretrain_model_id}/{dataset_name}/{os.path.basename(mesh_filename).split(".")[0]}_col_data.pkl'
        if os.path.exists(output_col_filename):
            # NOTE: pickle executes arbitrary code on load; only use cache files produced locally.
            with open(output_col_filename, 'rb') as f:
                col_data_total, sdf_mesh_paths = pickle.load(f)
            print(f"loaded: {output_col_filename}")
        else:
            gen_col_data_cls = gencol.ColDataGen()
            col_data_total, sdf_mesh_paths = gen_col_data_cls.gen_col_data_with_mesh(models, mesh_filename, create_obj_dataset, ndata=20000)
            col_data_total = jax.tree_util.tree_map(lambda x: jnp.array(x), col_data_total)
            os.makedirs(os.path.dirname(output_col_filename), exist_ok=True)
            with open(output_col_filename, 'wb') as f:
                pickle.dump((col_data_total, sdf_mesh_paths), f)
            print(f"saved: {output_col_filename}")

        # the last path is skipped -- NOTE(review): presumably it is the mesh being encoded; confirm in ColDataGen
        obj_ids = [models.asset_path_util.get_obj_id(sdf_mesh_path) for sdf_mesh_path in sdf_mesh_paths[:-1]]
        oriCORNs_col_data = models.canonical_latent_obj[np.array(obj_ids)]

    query_points = jnp.array(train_data['query_points'])
    signed_distance_gt = jnp.array(train_data['signed_distance'])

    apply_ray_loss = 'ray_points' in train_data
    ray_points = ray_directions = ray_hitting_gt = None
    if apply_ray_loss:
        ray_points = jnp.array(train_data['ray_points'])
        ray_directions = jnp.array(train_data['ray_directions'])
        ray_hitting_gt = jnp.array(train_data['ray_hitting_gt'])

    surface_pnts = train_data['surface_points']
    latent_shape = models.latent_shape
    fps_pnts = rcutil.FPS_padding(surface_pnts, latent_shape[0]*nfps_multiplier, jax.random.PRNGKey(0))

    # small random initial latent; positions come from the FPS points
    z = jax.random.normal(jax.random.PRNGKey(0), shape=(latent_shape[0]*nfps_multiplier, *latent_shape[1:]))*0.01
    latent_obj = loutil.LatentObjects(pos=jnp.mean(fps_pnts, axis=-2), rel_fps=None, z=z).set_fps_tf(fps_pnts)

    def _bce(pred, target):
        """Mean binary cross entropy on probabilities already clipped away from 0 and 1."""
        return jnp.mean(-target*jnp.log(pred) - (1-target)*jnp.log(1-pred))

    def loss_func(z, qpnts, sdf_gt, jkey, ray_pnts=None, ray_dirs=None, ray_h_gt=None, col_data_batch:Optional[structs.ColDataset]=None):
        """Return (total loss, metrics dict) for one batch; differentiated w.r.t. ``z``."""
        if col_data_batch is not None:
            # object A: fixed canonical objects posed as in the collision sample
            latent_obj_A: loutil.LatentObjects = oriCORNs_col_data[col_data_batch.obj_idx[:,0]]
            latent_obj_A = latent_obj_A.init_pos_zero()
            latent_obj_A = latent_obj_A.apply_scale(col_data_batch.obj_scale[:,0])
            latent_obj_A = latent_obj_A.apply_pq_z(col_data_batch.obj_pos[:,0], col_data_batch.obj_quat[:,0], models.rot_configs)
            latent_obj_A = jax.lax.stop_gradient(latent_obj_A)

            # object B: the object being fitted
            latent_obj_B:loutil.LatentObjects = latent_obj.replace(z=z)
            latent_obj_B = latent_obj_B.apply_scale(col_data_batch.obj_scale[:,1], center=jnp.zeros(3))
            latent_obj_B = latent_obj_B.apply_pq_z(col_data_batch.obj_pos[:,1], col_data_batch.obj_quat[:,1], models.rot_configs)

            if visualize:
                for ii in range(latent_obj_A.shape[0]):
                    objB_o3d = o3d.io.read_triangle_mesh(mesh_filename)
                    objB_o3d.compute_vertex_normals()
                    objB_o3d.scale(col_data_batch.obj_scale[ii,1,0,0], np.zeros(3))
                    objB_o3d.transform(tutil.pq2H(col_data_batch.obj_pos[ii,1], col_data_batch.obj_quat[ii,1]))

                    cur_col_gt = col_data_batch.col_gt[ii].squeeze(-1)
                    print(f"col_gt: {cur_col_gt}")
                    print(col_data_batch.obj_scale[ii,0])
                    obj_o3d_A = latent_obj_A[ii].get_fps_o3d(color=[1,0,0] if cur_col_gt else [0,0,1])
                    obj_o3d_B = latent_obj_B[ii].get_fps_o3d(color=[0,1,0])
                    o3d.visualization.draw_geometries([objB_o3d, obj_o3d_A, obj_o3d_B])

        bce_loss_entire = 0
        col_loss_entire = 0
        # defaults so the metrics dict is always complete, even when a loss term is disabled
        ray_bce_loss = 0
        occ_acc = 0
        ray_acc = 0
        col_acc = 0
        path_col_acc = 0
        for reduce_k in [8, 16]:
            # one independent key per stochastic call; a key is never used and split again
            jkey, occ_key, ray_key, col_key, path_gen_key, path_col_key = jax.random.split(jkey, 6)

            occ_res = models.occ_prediction(latent_obj.replace(z=z), qpnts, occ_key, reduce_k=reduce_k, train=True)
            occ_pred = jax.nn.sigmoid(occ_res)
            occ_pred = occ_pred.clip(1e-7, 1-1e-7)
            occ_gt = (sdf_gt < 0).astype(jnp.float32)
            assert occ_gt.shape == occ_pred.shape
            bce_loss_entire += _bce(occ_pred, occ_gt)

            # occ accuracy
            occ_acc = jnp.mean((occ_pred > 0.5) == occ_gt)

            if apply_ray_loss:
                ray_res = models.ray_prediction(latent_obj.replace(z=z), ray_pnts, ray_dirs, ray_key, reduce_k=reduce_k, depth_multiplier=1.0, train=True)[0]
                ray_pred = jax.nn.sigmoid(ray_res.squeeze(-1))
                ray_pred = ray_pred.clip(1e-7, 1-1e-7)
                ray_gt = (ray_h_gt > 0).astype(jnp.float32)
                assert ray_gt.shape == ray_pred.shape
                ray_bce_loss = _bce(ray_pred, ray_gt)
                bce_loss_entire += ray_bce_loss

                ray_acc = jnp.mean((ray_pred > 0.5) == ray_gt)

            if col_data_batch is not None:
                # apply col loss
                col_logits = models.apply('col_decoder', latent_obj_A, latent_obj_B, jkey=col_key, reduce_k=reduce_k, train=True)[0]
                col_pred = jax.nn.sigmoid(col_logits)
                col_pred = col_pred.clip(1e-7, 1-1e-7)
                col_gt = col_data_batch.col_gt.astype(jnp.float32) # binary
                assert col_gt.shape == col_logits.shape
                col_loss_entire += _bce(col_pred, col_gt)

                col_acc = jnp.mean((col_pred > 0.5) == col_gt)

                # path col loss: B swept along the segment between its first and last waypoint
                path_obj_A, path_obj_B = path_col_datagen(col_data_batch, latent_obj_A, latent_obj_B, path_gen_key)
                line_segment_B = path_obj_B[:,-1].pos - path_obj_B[:,0].pos

                path_col_logits = models.apply('col_decoder', latent_obj_A, 
                                                path_obj_B[:,0], line_segment_B=line_segment_B, 
                                                reduce_k=reduce_k, jkey=path_col_key, train=True)[0]

                path_col_gt = col_data_batch.col_gt.astype(jnp.float32) # binary
                path_col_pred = jax.nn.sigmoid(path_col_logits)
                path_col_pred = path_col_pred.clip(1e-7, 1-1e-7)
                col_loss_entire += _bce(path_col_pred, path_col_gt)

                path_col_acc = jnp.mean((path_col_pred > 0.5) == path_col_gt)

        # regularisation of z is currently disabled
        reg_loss = 0

        return bce_loss_entire + col_loss_entire + reg_loss, {'bce_loss':bce_loss_entire, 'col_loss':col_loss_entire, 'ray_bce_loss':ray_bce_loss,
                                                   'occ_acc':occ_acc, 'ray_acc':ray_acc,
                                                   'col_acc':col_acc, 'path_col_acc':path_col_acc, 'reg_loss': reg_loss}

    loss_grad_func = jax.grad(loss_func, has_aux=True)
    optimizer = optax.adamw(3e-4)

    def train_step(z, opt_state, jkey):
        """Run ``num_inner_itr`` optimiser updates and return averaged metrics."""
        loss_aux = None
        num_inner_itr = 4
        for _ in range(num_inner_itr):
            jkey, query_key, col_key, loss_key, ray_key = jax.random.split(jkey, 5)
            random_idx = jax.random.choice(query_key, jnp.arange(query_points.shape[0]), shape=(batch_size,), replace=False)
            qpnts = query_points[random_idx]
            sdf_gt = signed_distance_gt[random_idx]

            if col_data_total is not None:
                random_idx_col = jax.random.choice(col_key, jnp.arange(len(col_data_total)), shape=(max(batch_size//64, 1),), replace=False)
                col_data_batch = jax.tree_util.tree_map(lambda x: x[random_idx_col], col_data_total)
            else:
                col_data_batch = None

            if apply_ray_loss:
                # rays are an independent sample set: own key, own index range
                ray_idx = jax.random.choice(ray_key, jnp.arange(ray_points.shape[0]), shape=(batch_size,), replace=False)
                ray_pnts = ray_points[ray_idx]
                ray_dirs = ray_directions[ray_idx]
                ray_h_gt = ray_hitting_gt[ray_idx]
            else:
                ray_pnts = None
                ray_dirs = None
                ray_h_gt = None

            grad, loss_aux_ = loss_grad_func(z, qpnts, sdf_gt, loss_key, ray_pnts, ray_dirs, ray_h_gt, col_data_batch)
            if loss_aux is None:
                loss_aux = loss_aux_
            else:
                loss_aux = jax.tree_util.tree_map(lambda x, y: x+y, loss_aux, loss_aux_)
            updates, opt_state = optimizer.update(grad, opt_state, z)
            z = optax.apply_updates(z, updates)
        loss_aux = jax.tree_util.tree_map(lambda x: x/num_inner_itr, loss_aux)
        return z, opt_state, jkey, loss_aux

    # the debug visualisation inside loss_func needs concrete values, so skip jit in that mode
    if not visualize:
        train_step = jax.jit(train_step)

    opt_state = optimizer.init(z)

    jkey = jax.random.PRNGKey(0)
    for i in tqdm(range(niter)):
        z, opt_state, jkey, loss_aux = train_step(z, opt_state, jkey)
        if i%100==0:
            print(f"iter: {i}, loss: {loss_aux}")

    latent_obj = latent_obj.replace(z=z)

    if nfps_multiplier > 1:
        # split the enlarged latent into nfps_multiplier groups of latent_shape rows
        latent_obj = latent_obj.replace(z=z.reshape((nfps_multiplier, *latent_shape))).set_fps_tf(fps_pnts.reshape((nfps_multiplier, latent_shape[0], 3)))
        latent_obj = latent_obj.replace(pos=latent_obj.pos[None].repeat(nfps_multiplier, axis=0))

    if mesh_normalized:
        # undo the normalisation applied by create_obj_dataset (scale first, then translation)
        latent_obj = latent_obj.apply_scale(1/train_data['scale'], center=jnp.zeros(3))
        latent_obj = latent_obj.translate(-train_data['translation'])

    if mesh_filename is not None:
        dataset_name = os.path.basename(os.path.dirname(os.path.dirname(mesh_filename)))
        pretrain_model_id = models.pretrain_ckpt_id
        output_filename = f'assets_oriCORNs/{pretrain_model_id}/{dataset_name}/{os.path.basename(mesh_filename).split(".")[0]}.pkl'
    else:
        output_filename = 'assets_oriCORNs/tmp.pkl'
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    with open(output_filename, 'wb') as f:
        pickle.dump(latent_obj, f)
    print(f"saved: {output_filename}")

    return latent_obj

if __name__ == '__main__':
    # Script entry point: encode every mesh listed in `mesh_names`, then
    # reconstruct and display a mesh from the last encoded object.

    # np.random.seed(0)

    models = mutil.Models().load_pretrained_models()
    # models.col_decoder_model.dropout = 0.85
    # models.col_decoder_model.dropout

    # Alternative single-mesh invocations (disabled):
    # mesh_name = models.asset_path_util.get_obj_path_from_rel_path('kitchen/table/big_table_1.obj')
    # mesh_name = 'assets/ur5/meshes/rg6_gripper/modified/ee_rg6_gripper.obj'
    # mesh_name = 'assets/room/raw/wall.obj'
    # mesh_name = 'assets/room/raw/wall_no_floor.obj'
    # mesh_name = 'assets/room/raw/room_level2.obj'
    # mesh_name = 'assets/room/raw/room_no_floor_v2.obj'
    # mesh_name = 'assets/room/raw/room_no_floor_v3.obj'
    # mesh_name = 'assets/assembly/raw/hole_v3.obj'
    # mesh_name = 'assets/assembly/raw/peg_v3.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_moving_obj.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_moving_obstacles.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_obstacles_v4.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_shelf_v4.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_moving_obj_v4.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_obstacle_cart_v6.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_obstacle_beam.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_obstacle_beam.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_obstacle_beam_v5.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_obstacle_shelf_v5.obj'
    # mesh_name = 'assets/construction_site/raw/construction_site_obstacle_pipe_v5.obj'
    # mesh_name = 'assets/assembly/raw/peg_v5.obj'
    # mesh_name = 'assets/assembly/raw/hole_v5.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_obstacle_cart_v6.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_obstacle3_roadblock.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_pickaxe.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_obstacle1_beam.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_obstacle1_roadcone.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_obstacle2_drum.obj'
    # mesh_name = 'assets/construction_site/cvx/coacd/construction_site_pickaxe_v4.obj'
    # oriCORN = encode_mesh(models, mesh_name, nfps_multiplier=1, niter=10000, use_col_loss=False, visualize=False)

    # Meshes to encode; (un)comment entries to select them.
    mesh_names = [
        # 'assets/construction_site/cvx/coacd/construction_site_obstacle_cart_v6.obj',
        # 'assets/construction_site/cvx/coacd/construction_site_obstacle3_roadblock.obj',
        # 'assets/construction_site/cvx/coacd/construction_site_pickaxe.obj',
        # 'assets/construction_site/cvx/coacd/construction_site_obstacle1_beam.obj',
        # 'assets/construction_site/cvx/coacd/construction_site_obstacle1_roadcone.obj',
        # 'assets/construction_site/cvx/coacd/construction_site_obstacle2_drum.obj',
        # 'assets/construction_site/cvx/coacd/construction_site_obstacle2_roadblock.obj',
        'assets/assembly/cvx/coacd/peg_v5.obj',
        'assets/assembly/raw/hole_v5.obj',
        # 'assets/assembly/cvx/coacd/peg_var2.obj',
        # 'assets/assembly/raw/peg_var2.obj',
        # 'assets/assembly/raw/hole_var2.obj',
        # 'assets/assembly/cvx/coacd/hole_var4.obj',
        # 'assets/assembly/raw/hole_var4.obj',
        # 'assets/assembly/raw/peg_var4.obj',
        # 'assets/assembly/cvx/coacd/peg_var5.obj',
        # 'assets/assembly/raw/peg_var5.obj',
        # 'assets/assembly/cvx/coacd/hole_var6.obj',
        # 'assets/assembly/raw/hole_var6.obj',
        # 'assets/assembly/raw/peg_var6.obj',
        # 'assets/assembly/cvx/coacd/hole_var7.obj',
        # 'assets/assembly/raw/peg_var7.obj',
    ]

    # stays None when every entry above is commented out
    oriCORN = None
    for mesh_name in mesh_names:
        print(mesh_name)
        oriCORN = encode_mesh(models, mesh_name, niter=10000, use_col_loss=False, visualize=False)

    # oriCORN = models.asset_path_util.get_encoded_obj('assets/assembly/cvx/coacd/hole_var6.obj', models.pretrain_ckpt_id)

    # only the last encoded object is reconstructed and shown
    if oriCORN is not None:
        rcutil.create_scene_mesh_from_oriCORNs(oriCORN, dec=jax.jit(models.occ_prediction), level=0.0, qp_bound=None, density=200, ndiv=100, visualize=True)