import numpy as np
import glob
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import einops
from functools import partial
import datetime
import os
import shutil
import pickle
import copy
from pathlib import Path
from tqdm import tqdm
import random
from flax.struct import dataclass as flax_dataclass
from dataclasses import replace
import argparse
import wandb
import shlex, subprocess
import jax.debug as jdb
from scipy.spatial.transform import Rotation as sciR
import logging

try:
    # Optional experiment tracker: when the package is missing or init fails,
    # training continues with vessl logging disabled.
    import vessl
    vessl.init()
    vessl_on = True
except Exception:
    # `Exception` rather than a bare `except:` so KeyboardInterrupt/SystemExit
    # are not swallowed during start-up.
    vessl_on = False

from tensorboardX import SummaryWriter

import util.reconstruction_util as rcutil
import util.ev_util.rotm_util as rotm_util
import util.transform_util as tutil
from util.model_util import ColDecoderSceneV2
import util.model_util as mutil
import util.latent_obj_util as loutil
from util.structs import ColDataset, OccDataset
from util.dotenv_util import SSH_IP, SSH_PASSWORD, SSH_PORT


def debug_callback(inputs):
    """Print the given value prefixed with ``inputs: ``."""
    message = "inputs: {}".format(inputs)
    print(message)

def tsdf_func(sd, slope_positive=0.1, slope_negative=0.2, negative_truncated_factor=2.0, truncated_val=0.06):
    """Map signed distances onto a piecewise-linear, soft-truncated SDF.

    Four linear pieces that meet at two knees:
      * ``0 <= sd < truncated_val``: slope ``1/truncated_val`` (reaches +1 at the knee);
      * ``sd >= truncated_val``: slope ``slope_positive`` starting from (truncated_val, 1);
      * ``-truncated_val/negative_truncated_factor <= sd < 0``: slope
        ``negative_truncated_factor/truncated_val`` (reaches -1 at the knee);
      * below that negative knee: slope ``slope_negative`` starting from (knee, -1).
    Entries matching none of the pieces (e.g. NaN) are returned unchanged.
    """
    # Knee points stored as (sd, value) NumPy arrays so that the dtype promotion
    # of the arithmetic below matches the NumPy-scalar operands used here.
    knee_pos = np.array([truncated_val, 1.0])
    knee_neg = np.array([-truncated_val/negative_truncated_factor, -1.0])

    inner_pos = (1/truncated_val) * sd
    outer_pos = slope_positive * (sd - knee_pos[0]) + knee_pos[1]
    inner_neg = (1/truncated_val*negative_truncated_factor) * sd
    outer_neg = slope_negative * (sd - knee_neg[0]) + knee_neg[1]

    in_inner_pos = jnp.logical_and(sd < knee_pos[0], sd >= 0)
    in_outer_pos = sd >= knee_pos[0]
    in_inner_neg = jnp.logical_and(sd >= knee_neg[0], sd < 0)
    in_outer_neg = sd < knee_neg[0]

    # Outermost `where` = highest priority, so the nesting order reproduces
    # "later piece overrides earlier piece" for any parameter choice.
    return jnp.where(
        in_outer_neg, outer_neg,
        jnp.where(
            in_inner_neg, inner_neg,
            jnp.where(
                in_outer_pos, outer_pos,
                jnp.where(in_inner_pos, inner_pos, sd),
            ),
        ),
    )

def col_sdf(col_gt, min_distance, threshold=0.04):
    """Negated truncated-SDF label for a minimum-distance value.

    Delegates to ``tsdf_func`` with slope 0.3 outside the truncation band on both
    sides and a symmetric band (``negative_truncated_factor=1``) of half-width
    ``threshold``. The sign is flipped, so negative distances (penetration) map
    to positive values. ``col_gt`` is accepted for interface compatibility but
    is not used.
    """
    return -tsdf_func(
        min_distance,
        slope_positive=0.3,
        slope_negative=0.3,
        negative_truncated_factor=1.0,
        truncated_val=threshold,
    )


def col_patch_label_np(obj_AB_idx, obj_AB_scale, obj_AB_pos, obj_AB_quat, col_patch_gt, col_pnts_AB, canonical_FPSs):
    """NumPy: flag the FPS points nearest to at least one valid contact point.

    Canonical FPS points of each object (``canonical_FPSs[obj_AB_idx]``) are
    scaled, rotated with ``tutil.qaction_np`` and translated into the pair's
    pose. For every contact point in ``col_pnts_AB``, the index of the nearest
    transformed FPS point is found. Contact points whose ``col_patch_gt`` flag is
    False are discarded. The result is a boolean mask over FPS points.

    NOTE(review): shapes are not fixed by this function. Presumably obj_AB_idx is
    (B, 2), canonical_FPSs is (n_obj, N, 3), col_pnts_AB is (B, 2, P, 3) and
    col_patch_gt is (B, P), giving a (B, 2, N) result. Confirm against callers.
    """
    placed_fps = canonical_FPSs[obj_AB_idx] * obj_AB_scale[..., None, None]
    placed_fps = tutil.qaction_np(obj_AB_quat[..., None, :], placed_fps) + obj_AB_pos[..., None, :]

    # Pairwise distances: FPS points along axis -2, contact points along axis -1.
    dist = np.linalg.norm(placed_fps[..., None, :] - col_pnts_AB[..., None, :, :], axis=-1)
    nearest_fps = np.argmin(dist, axis=-2)
    # -1 never equals an FPS index, so invalid contacts drop out of the one-hot.
    nearest_fps = np.where(col_patch_gt[..., None, :], nearest_fps, -1)

    num_fps = placed_fps.shape[-2]
    hits = nearest_fps[..., None] == np.arange(num_fps)
    return hits.any(axis=-2)


def col_patch_label(obj_AB_idx, obj_AB_scale, obj_AB_pos, obj_AB_quat, col_patch_gt, col_pnts_AB, canonical_FPSs):
    """JAX: flag each object's FPS points nearest to a valid contact point.

    Canonical FPS points are scaled, rotated with ``tutil.qaction`` and
    translated into the pair's pose. Contact points along the ``d`` axis of
    ``col_pnts_AB`` are pooled into one list. ``col_patch_gt`` is tiled twice
    with einops so that it lines up with the pooled list. Every pooled contact
    is then matched to each object's FPS points. Invalid contacts get an
    infinite distance before the argmin, and their index is forced to -1
    afterwards. Returns a boolean mask over FPS points.
    """
    placed_fps = canonical_FPSs[obj_AB_idx] * obj_AB_scale
    placed_fps = tutil.qaction(obj_AB_quat[..., None, :], placed_fps) + obj_AB_pos[..., None, :]

    valid_contact = einops.repeat(col_patch_gt, '... i -> ... (r i)', r=2)
    pooled_contacts = einops.rearrange(col_pnts_AB, '... d p i -> ... 1 (d p) i')

    dist = jnp.linalg.norm(placed_fps[..., None, :] - pooled_contacts[..., None, :, :], axis=-1)
    dist = jnp.where(valid_contact[..., None, None, :], dist, jnp.inf)
    nearest_fps = jnp.argmin(dist, axis=-2)
    nearest_fps = jnp.where(valid_contact[..., None, :], nearest_fps, -1)

    hits = nearest_fps[..., None] == jnp.arange(placed_fps.shape[-2])
    return jnp.any(hits, axis=-2)


def cal_min_distance(col_res):
    """Convert a collision record into a signed "minimum distance" label.

    The record layout is identified by ``len(col_res)``:

    * 3 or 5 entries ``(col_gt, min_distance, min_direction, ...)``:
      colliding samples get ``-1.0``, others keep ``min_distance``.
    * 4 entries ``(col_gt, min_distance, min_direction, penetration_depth)``:
      samples with ``|min_distance| < 1e-7`` get ``-1.0`` when colliding.
      Other colliding samples get ``min(-penetration_depth, 0)``.
      Non-colliding samples keep ``min_distance``.
    * 6 entries ``(col_gt, min_distance, min_direction, _, _, penetration_depth)``:
      colliding samples get ``min(-penetration_depth / 500, 0)``, others keep
      ``min_distance``.

    For the 4- and 6-entry layouts, a 2-D ``col_gt`` means every entry is first
    sliced with ``x[:, 0]``. Labels may be scalars (Python or NumPy) or arrays.
    Arrays are processed element-wise, and the caller's arrays are never
    modified.

    Raises:
        ValueError: if ``len(col_res)`` is not 3, 4, 5 or 6.
    """
    num_fields = len(col_res)

    if num_fields in (3, 5):
        col_gt, min_distance = col_res[0], col_res[1]
        if np.ndim(col_gt) == 0:
            return -1.0 if col_gt else min_distance
        return np.where(np.asarray(col_gt, dtype=bool), -1.0, min_distance)

    if num_fields == 4:
        if np.ndim(col_res[0]) == 2:
            col_res = jax.tree_util.tree_map(lambda x: x[:, 0], col_res)
        col_gt, min_distance, _, penetration_depth = col_res
        # |d| ~ 0 marks samples without a meaningful distance (uniformly sampled data).
        near_zero = np.abs(min_distance) < 1e-7
        if np.ndim(col_gt) == 0:
            if near_zero:
                return -1.0 if col_gt else min_distance
            if col_gt:
                # collision and penetrated
                return np.minimum(-penetration_depth, 0)
            return min_distance
        col_gt = np.asarray(col_gt, dtype=bool)
        penetrated = np.minimum(-np.asarray(penetration_depth), 0)
        return np.where(col_gt, np.where(near_zero, -1.0, penetrated), min_distance)

    if num_fields == 6:
        if np.ndim(col_res[0]) == 2:
            col_res = jax.tree_util.tree_map(lambda x: x[:, 0], col_res)
        col_gt, min_distance, _, _, _, penetration_depth = col_res
        if np.ndim(col_gt) == 1:
            col_gt = np.asarray(col_gt, dtype=bool)  # boolean mask, never fancy int indexing
            # Copy so the caller's (possibly immutable jax) array is not written to.
            min_distance_out = np.array(min_distance, copy=True)
            min_distance_out[col_gt] = np.minimum(-np.asarray(penetration_depth)[col_gt] / 500., 0)
            # Sanity check: every colliding sample must end up with a non-positive label.
            assert np.all(col_gt == (min_distance_out <= 0))
            return min_distance_out
        if col_gt:
            # collision and penetrated
            return np.minimum(-penetration_depth / 500., 0)
        return min_distance

    raise ValueError(f'col_res should have 3, 4, 5 or 6 entries, got {num_fields}')
    
def BCE_loss(col_pred_logit, col_gt):
    """Summed binary cross-entropy between ``sigmoid(col_pred_logit)`` and ``col_gt``.

    If the two ranks differ by one, the trailing axis of the higher-rank
    argument is squeezed. Probabilities are clipped to ``[1e-7, 1 - 1e-7]``
    before taking logs.
    """
    rank_gap = col_pred_logit.ndim - col_gt.ndim
    if rank_gap == 1:
        col_pred_logit = col_pred_logit.squeeze(-1)
    elif rank_gap == -1:
        col_gt = col_gt.squeeze(-1)

    prob = jax.nn.sigmoid(col_pred_logit).clip(1e-7, 1 - 1e-7)
    assert col_gt.shape == prob.shape

    log_likelihood = col_gt * jnp.log(prob) + (1 - col_gt) * jnp.log(1 - prob)
    return -jnp.sum(log_likelihood)


def binary_focal_loss(logits, targets, alpha=0.75, gamma=2.0, apply_col_loss_mask:int=0, *, loss_type=None):
    """
    Summed binary focal / cross-entropy loss in JAX.

    Args:
        logits: (B, ...) raw logits for the positive class. If its rank differs
            from ``targets`` by one, the trailing axis of the higher-rank
            argument is squeezed.
        targets: (B, ...) in {0,1}; treated as constants (stop_gradient).
        alpha: weight for positive examples (focal loss only).
        gamma: focusing parameter (focal loss only).
        apply_col_loss_mask: which elements enter the sum:
            0: every element;
            1: along the last axis, rows containing a positive keep only their
               positives; rows without any positive keep every element;
            2: rows (last axis) containing any positive are dropped entirely.
        loss_type: ``'focal'`` for alpha/gamma-weighted focal loss; any other
            value gives plain BCE. ``None`` (default) reads the module-level
            ``args.binary_loss_type``. When no ``args`` global exists (e.g. this
            module was imported rather than run), it falls back to ``'bce'``,
            the CLI default.
    Returns:
        A scalar: the SUM (not the mean) of the masked per-element loss.
    Raises:
        ValueError: if ``apply_col_loss_mask`` is not 0, 1 or 2.
    """
    if logits.ndim - targets.ndim == 1:
        logits = logits.squeeze(-1)
    elif logits.ndim - targets.ndim == -1:
        targets = targets.squeeze(-1)

    assert logits.shape == targets.shape, f'logits: {logits.shape}, targets: {targets.shape}'

    if loss_type is None:
        try:
            loss_type = args.binary_loss_type
        except NameError:
            loss_type = 'bce'

    # Sigmoid + clipping (positional bounds: the a_min/a_max keywords are deprecated in JAX).
    p = jax.nn.sigmoid(logits)
    p = jnp.clip(p, 1e-7, 1 - 1e-7)
    t = jax.lax.stop_gradient(targets.astype(jnp.float32))

    if apply_col_loss_mask == 0:
        valid_loss_mask = jnp.ones_like(t, dtype=jnp.bool_)
    elif apply_col_loss_mask == 1:
        valid_loss_mask = jnp.where(jnp.any(t > 0.5, axis=-1, keepdims=True), t > 0.5, True)
    elif apply_col_loss_mask == 2:
        valid_loss_mask = jnp.where(jnp.any(t > 0.5, axis=-1, keepdims=True), False, True)
    else:
        raise ValueError('apply_col_loss_mask should be 0, 1 or 2')

    # standard binary cross-entropy
    bce = - (t * jnp.log(p) + (1. - t) * jnp.log(1. - p))

    if loss_type == 'focal':
        # alpha_t depends on whether target=1 (pos) or 0 (neg)
        alpha_t = alpha * t + (1. - alpha) * (1. - t)
        # p_t is p if y=1 else (1-p)
        p_t = t * p + (1. - t) * (1. - p)
        # focal scaling = alpha_t * (1 - p_t)^gamma
        loss = alpha_t * (1. - p_t)**gamma * bce
    else:
        loss = bce

    return jnp.sum(jnp.where(valid_loss_mask, loss, 0))

def cal_grad_loss(fps_grad):
    """Penalise the spread of per-point gradient magnitudes around their mean.

    Args:
        fps_grad: array whose last axis holds one gradient vector per point
            (presumably d(logit)/d(FPS position), shape (..., 3); TODO confirm).

    Returns:
        Scalar ``mean(where(valid, |‖g‖ - m|, 0))``. A point is "valid" when
        ``‖g‖ > 1e-6``. ``m = sum(‖g‖) / max(#valid, 1)`` and is treated as a
        constant (stop_gradient). The outer mean divides by ALL points,
        including invalid ones.
    """
    raw_norm = jnp.linalg.norm(fps_grad, axis=-1)
    valid_norm_mask = raw_norm > 1e-6
    # raw_norm only feeds the mask and a stop_gradient'ed value, so its
    # (undefined at 0) derivative never reaches the backward pass.
    grad_norm_mean = jnp.sum(raw_norm) / jnp.maximum(jnp.sum(valid_norm_mask), 1)
    grad_norm_mean = jax.lax.stop_gradient(grad_norm_mean)

    # "Double where": the derivative of jnp.linalg.norm at the zero vector is
    # NaN, and masking the *output* with jnp.where does not stop that NaN from
    # propagating into gradients. Masked vectors are swapped for a harmless
    # non-zero value *before* taking the norm. Valid entries keep identical values.
    safe_grad = jnp.where(valid_norm_mask[..., None], fps_grad, jnp.ones_like(fps_grad))
    safe_norm = jnp.linalg.norm(safe_grad, axis=-1)
    deviation = jnp.where(valid_norm_mask, jnp.abs(safe_norm - grad_norm_mean), 0.0)
    return jnp.mean(deviation)

def cal_col_loss(col_dataset:ColDataset, col_pred_logit, calculate_patch_col=False, apply_patch_col_mask=False, fps_grad=None, grad_loss_weight=None):
    """Collision (and optional per-patch) classification loss plus metrics.

    The entries of ``col_pred_logit`` are matched in order against
    ``col_gt`` (name ``col``), then ``fps_col_labels[..., 0, :]``
    (``col_patch_A``) and ``fps_col_labels[..., 1, :]`` (``col_patch_B``).
    Only the first entry is used unless ``calculate_patch_col`` is truthy and
    ``col_dataset.fps_col_labels`` is not None.

    Each used entry adds ``binary_focal_loss``. It uses alpha 0.5 for ``col``
    and 0.8 for the patch entries. Patch entries use mask mode 1 when
    ``apply_patch_col_mask`` is set. Each entry also reports accuracy,
    precision, recall and F1, with predictions thresholded at logit 0.

    Args:
        col_dataset: batch providing ``col_gt`` and optional ``fps_col_labels``.
        col_pred_logit: sequence of logit arrays (squeezed on the last axis for metrics).
        calculate_patch_col: also train / evaluate the patch heads.
        apply_patch_col_mask: use mask mode 1 for the patch losses.
        fps_grad: optional per-point gradients. When given, ``cal_grad_loss`` is
            added with weight ``grad_loss_weight`` and reported as ``'grad_loss'``.
        grad_loss_weight: weight of that term. ``None`` reads the module-level
            ``args.grad_loss_weight``. Irrelevant when ``fps_grad`` is None.

    Returns:
        ``(total_loss, metrics)``, where metrics holds ``<name>_acc``,
        ``<name>_precision``, ``<name>_recall`` and ``<name>_f1`` for each
        used entry, plus ``grad_loss`` when computed.
    """
    col_acc_dict = {}
    grad_term = 0.0
    if fps_grad is not None:
        if grad_loss_weight is None:
            grad_loss_weight = args.grad_loss_weight
        grad_loss = cal_grad_loss(fps_grad)
        col_acc_dict['grad_loss'] = grad_loss
        grad_term = grad_loss_weight * grad_loss

    if col_dataset.fps_col_labels is not None:
        gt_list = [col_dataset.col_gt, col_dataset.fps_col_labels[...,0,:], col_dataset.fps_col_labels[...,1,:]]
        name_list = ['col', 'col_patch_A', 'col_patch_B']
    else:
        gt_list = [col_dataset.col_gt]
        name_list = ['col']
    alpha_list = [0.5, 0.8, 0.8]
    loss_weight = [1, 1, 1]
    apply_col_loss_mask_bool = [0, 1, 1] if apply_patch_col_mask else [0, 0, 0]
    use_patch_heads = calculate_patch_col and (col_dataset.fps_col_labels is not None)

    col_loss = 0
    for idx, col_pred_logit_ in enumerate(col_pred_logit):
        col_gt = gt_list[idx]
        name = name_list[idx]

        # focal / BCE loss
        col_loss_ = binary_focal_loss(col_pred_logit_, col_gt, alpha=alpha_list[idx], apply_col_loss_mask=apply_col_loss_mask_bool[idx])
        col_loss += col_loss_ * loss_weight[idx]

        col_pred_logit_ = col_pred_logit_.squeeze(-1)
        predicted_pos = col_pred_logit_ > 0.

        # col accuracy
        col_acc_dict[name + '_acc'] = jnp.mean((col_gt).astype(jnp.int32) == predicted_pos.astype(jnp.int32))

        # F1 score
        TP = jnp.sum((col_gt == 1) & predicted_pos)
        FP = jnp.sum((col_gt == 0) & predicted_pos)
        FN = jnp.sum((col_gt == 1) & (col_pred_logit_ <= 0.))
        precision = TP / (TP + FP + 1e-7)
        recall = TP / (TP + FN + 1e-7)
        f1_score = 2 * precision * recall / (precision + recall + 1e-7)
        col_acc_dict[name+'_precision'] = precision
        col_acc_dict[name+'_recall'] = recall
        col_acc_dict[name+'_f1'] = f1_score

        if not use_patch_heads:
            break

    return col_loss + grad_term, col_acc_dict



def create_obj_dataset(sdf_data_path, jkey, rel_path=None, visualize=False):
    """Load one object's pickled SDF data and build its OccDataset (+ latent init).

    If ``sdf_data_path`` is missing locally, the remote directory
    ``research/object_set/<dirname(rel_path)>`` is first fetched from ``SSH_IP``
    with ``sshpass`` + ``rsync``.

    Reads module-level globals set in ``__main__``: ``args``, ``DATASET_SIZE``,
    ``TREE_WIDTH`` and ``TREE_FEAT_DIM``.

    Args:
        sdf_data_path: local path of the pickled SDF data.
        jkey: JAX PRNG key for latent initialisation and FPS sampling.
        rel_path: path relative to the remote ``object_set`` root. Only needed
            when the file has to be downloaded.
        visualize: unused.

    Returns:
        ``(occ_dataset, latent_obj)``. ``latent_obj`` is None when
        ``args.use_encoder`` is set.

    Raises:
        ValueError: a download is needed but ``rel_path`` is None.
        subprocess.CalledProcessError: the rsync download failed.
    """
    if not os.path.exists(sdf_data_path):
        if rel_path is None:
            raise ValueError(f'{sdf_data_path} does not exist and no rel_path was given to download it')
        os.makedirs(os.path.dirname(sdf_data_path), exist_ok=True)
        remote_dir = f"research/object_set/{os.path.dirname(rel_path)}"
        local_dir = os.path.dirname(os.path.dirname(sdf_data_path))
        # argv list (no shell, no re-splitting of a formatted string), so a
        # password or path containing spaces/quotes is passed through verbatim.
        # NOTE(review): `sshpass -p` still exposes the password in the process
        # list; consider `sshpass -e` (SSHPASS env var) or key-based auth.
        command = [
            'sshpass', '-p', str(SSH_PASSWORD),
            'rsync', '--ignore-existing', '-av',
            '-e', f'ssh -p {SSH_PORT}',
            f'{SSH_IP}:{remote_dir}', local_dir,
        ]
        # check=True raises CalledProcessError on a non-zero exit status.
        subprocess.run(command, check=True)
        print(f'{sdf_data_path} is downloaded from the server')

    # NOTE(review): pickle.load can execute arbitrary code; this file may have
    # been fetched from the remote server above, so only use trusted sources.
    with open(sdf_data_path, 'rb') as f:
        occ_data = pickle.load(f)

    surface_pnts = occ_data['surface_points']
    query_points = occ_data['query_points'][:DATASET_SIZE]
    signed_distance = occ_data['signed_distance'][:DATASET_SIZE]
    normal_vector = occ_data['normal_vector'][:DATASET_SIZE]

    rays = occ_data['rays'][:DATASET_SIZE]
    t_hit = occ_data['t_hit'][:DATASET_SIZE, ..., None]

    if args.global_rep_baseline:
        # single global latent anchored at the origin
        jkey, subkey = jax.random.split(jkey)
        tree_param = 0.05*jax.random.normal(subkey, shape=(1, 8, TREE_FEAT_DIM)).astype(jnp.float32)
        root_pnts = jnp.zeros((1, 3)).astype(jnp.float32)
    else:
        if args.noise_init_z:
            jkey, subkey = jax.random.split(jkey)
            tree_param = 0.05*jax.random.normal(subkey, shape=(TREE_WIDTH, 8, TREE_FEAT_DIM)).astype(jnp.float32)
        else:
            tree_param = jnp.zeros((TREE_WIDTH, 8, TREE_FEAT_DIM)).astype(jnp.float32)

        # surface points
        jkey, subkey = jax.random.split(jkey)
        # NOTE: passes the post-split `jkey` (not `subkey`). Both are fresh keys
        # and `jkey` is not reused afterwards, so this is safe; kept as-is to
        # preserve existing seeded sampling.
        root_pnts = rcutil.FPS_padding(surface_pnts, TREE_WIDTH, jkey)

    # Normals are not passed on: the OccDataset always receives None here.
    normal_vector = None
    if args.use_encoder:
        latent_obj = None
        occ_dataset = OccDataset(query_points, signed_distance, normal_vector, rays, t_hit, surface_pnts[..., :args.enc_train_npnts*3, :])
    else:
        latent_obj = loutil.LatentObjects().set_z_list(root_pnts, tree_param)
        occ_dataset = OccDataset(query_points, signed_distance, normal_vector, rays, t_hit)
    return occ_dataset, latent_obj


def path_col_datagen(col_datapoints:ColDataset, oriCORNs_pair:loutil.LatentObjects, jkey, debug=False):
    """Create short straight-line "paths" of translated copies around each pair.

    Each object in each pair gets ``npath`` translated copies placed along its
    own random line. Colliding pairs (``col_gt`` True) use a random Gaussian
    direction. Non-colliding pairs use the cross product of ``min_direction``
    with that random vector, i.e. a direction perpendicular to
    ``min_direction``. Path length is drawn from U(0.05, 0.5). The ``npath``
    offsets are drawn uniformly along the path and shifted so that one
    randomly chosen offset is 0.

    Returns:
        ``(merged_oriCORNs_A, merged_oriCORNs_B)``: translated copies of the
        pair's index-0 and index-1 objects, with a trailing ``npath`` axis.
    """
    npath = 5
    col_gt = col_datapoints.col_gt
    min_direction = col_datapoints.min_direction
    outer_shape = min_direction.shape[:-1] # presumably (NB, )

    jkey, subkey = jax.random.split(jkey)
    path_direction = jax.random.normal(subkey, shape=outer_shape + (2, 3))
    path_direction_perp = jnp.cross(min_direction[..., None, :], path_direction)
    path_direction = jnp.where(col_gt[..., None, None], path_direction, path_direction_perp)
    # A zero-length direction (e.g. an all-zero min_direction) would otherwise
    # normalise to NaN and poison every translated copy; such paths collapse to 0.
    path_norm = jnp.linalg.norm(path_direction, axis=-1, keepdims=True)
    path_direction = path_direction / jnp.maximum(path_norm, 1e-8) # (NB, 2, 3)

    jkey, subkey = jax.random.split(jkey)
    path_len = jax.random.uniform(subkey, shape=outer_shape+(2, 1,), minval=0.05, maxval=0.5)
    jkey, subkey = jax.random.split(jkey)
    path_distance = jax.random.uniform(subkey, shape=outer_shape+(2, npath,)) * path_len # (NB, 2, npath)
    jkey, subkey = jax.random.split(jkey)
    random_idx = jax.random.randint(subkey, shape=outer_shape+(2, 1,), minval=0, maxval=path_distance.shape[-1])
    # centre the offsets on one randomly chosen sample
    path_distance = path_distance - jnp.take_along_axis(path_distance, random_idx, axis=-1) # (NB, 2, npath)
    path_ptb = path_direction[..., None, :] * path_distance[..., None] # (NB, 2, npath, 3)

    oriCORNs_pair = oriCORNs_pair.extend_and_repeat_outer_shape(npath, -1)
    oriCORNs_pair = oriCORNs_pair.translate(path_ptb)

    merged_oriCORNs_A, merged_oriCORNs_B = oriCORNs_pair[:, 0], oriCORNs_pair[:, 1]

    if debug:
        # visualise every path copy of both objects, one window per pair
        import open3d as o3d
        for i in range(outer_shape[0]):
            geometries = []
            for j in range(2):
                for k in range(npath):
                    path_pcd = o3d.geometry.PointCloud()
                    path_pcd.points = o3d.utility.Vector3dVector(np.array(oriCORNs_pair[i, j, k].fps_tf))
                    geometries.append(path_pcd)
            coordinate_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=1.2, origin=np.array([0., 0., 0.]))
            geometries.append(coordinate_frame)
            o3d.visualization.draw_geometries(geometries, point_show_normal=True)

    return merged_oriCORNs_A, merged_oriCORNs_B


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    # convert variables to argparse
    parser.add_argument("--dataset_size", type=int, default=80000)
    parser.add_argument("--obj_batch_size", type=int, default=64)
    parser.add_argument("--col_dataset_load_no", type=int, default=100000)
    parser.add_argument("--query_batch_size", type=int, default=64)
    parser.add_argument("--dropout_rate", type=float, default=0.8)
    parser.add_argument("--sdf_loss_weight", type=float, default=0.3)
    parser.add_argument("--col_loss_weight", type=float, default=1)
    parser.add_argument("--ray_loss_weight", type=float, default=0.3)
    parser.add_argument("--grad_loss_weight", type=float, default=3.0)
    parser.add_argument("--multihead_no", type=int, default=4)
    parser.add_argument("--tree_feat_dim", type=int, default=2)
    parser.add_argument("--tree_width", type=int, default=64)
    parser.add_argument("--col_reduce_k", type=int, default=20)
    parser.add_argument("--ckpt_dir", type=str, default=None)
    parser.add_argument("--sdf_base_dir", type=str, default='/home/dongwon/research/object_set')
    parser.add_argument("--col_data_base_dir", type=str, default=None)
    parser.add_argument("--inner_itr_no", type=int, default=50)
    parser.add_argument("--reg_loss_weight", type=float, default=1)
    parser.add_argument("--global_rep_baseline", type=int, default=0)
    parser.add_argument("--update_sdf_dir", type=str, default=None)
    parser.add_argument("--add_path_col_loss", type=int, default=1)
    parser.add_argument("--rotm_type", type=str, default='custom')
    parser.add_argument("--save_interval", type=int, default=1000)
    parser.add_argument("--col_dec_version", type=int, default=3, help='deprecated')
    parser.add_argument("--binary_loss_type", type=str, default='bce', choices=['focal', 'bce'])
    parser.add_argument("--train_rot_aug", type=int, default=1)
    parser.add_argument("--train_patch_col", type=int, default=0)
    parser.add_argument("--col_dec_depth", type=int, default=3)
    parser.add_argument("--pos_emb_size_divider", type=int, default=1)
    parser.add_argument("--feat_size_divider", type=int, default=1)
    parser.add_argument("--noise_init_z", type=int, default=1)
    parser.add_argument("--normalize_eps", type=float, default=0.01)
    parser.add_argument("--use_encoder", type=int, default=1)
    parser.add_argument("--enc_train_npnts", type=int, default=1024)
    parser.add_argument("--debug", type=int, default=0)
    
    

    args = parser.parse_args()

    if args.global_rep_baseline == 1:
        print('global_rep_baseline is activated')
        args.train_patch_col = 0
        args.dropout_rate = 1.0
        args.tree_width = 1

    if args.col_dec_version == 1:
        args.train_rot_aug = 1

    # Replace variables with argparse
    DATASET_SIZE = args.dataset_size
    OBJ_BATCH_SIZE = args.obj_batch_size
    COL_DATASET_LOAD_NO = args.col_dataset_load_no
    QUERY_BATCH_SIZE = args.query_batch_size
    DROPOUT_RATE = args.dropout_rate
    TREE_FEAT_DIM = args.tree_feat_dim
    TREE_WIDTH = args.tree_width
    CKPT_DIR = args.ckpt_dir

    jkey = jax.random.PRNGKey(0)

    if jax.device_count('gpu') == 0:
        print('no gpu found. End process')
        raise ValueError('no gpu found. End process')
    else:
        print('device found: ', jax.devices())

    LOG_DIR = 'logs/' + datetime.datetime.now().strftime("%m%d%Y_%H%M%S")
    # rot_configs = rotm_util.init_rot_config(seed=0, dim_list=[1,2], rot_type='wigner')
    rot_configs = rotm_util.init_rot_config(seed=0, dim_list=[1,2], rot_type=args.rotm_type)

    # # load obj files from txt
    with open(f'dataset/sdf_dirs.txt', 'r') as f:
        sdf_paths = f.readlines()
    sdf_paths = [of.strip() for of in sdf_paths]

    if args.update_sdf_dir is not None:
        assert args.ckpt_dir is not None
        with open(os.path.join(CKPT_DIR, 'sdf_dirs.txt'), 'r') as f:
            sdf_paths_old = f.readlines()
        sdf_paths_old = [of.strip() for of in sdf_paths_old]
        # update index
        old_to_new_idx = np.array([sdf_paths.index(of) for of in sdf_paths_old])


    # if args.update_sdf_dir is not None:
    #     assert args.ckpt_dir is not None
    #     with open(args.update_sdf_dir, 'r') as f:
    #         obj_files_new = f.readlines()
    #     obj_files_new = [of.strip() for of in obj_files_new]

    #     # update index
    #     old_to_new_idx = np.array([obj_files_new.index(of) for of in sdf_paths])
    #     sdf_paths = obj_files_new

    sdf_paths = [of.split('object_set/')[-1] for of in sdf_paths]
    
    rel_sdf_paths = copy.deepcopy(sdf_paths)

    if args.sdf_base_dir is not None:
        sdf_paths = [os.path.join(args.sdf_base_dir, of[1:] if of[0]=='/' else of) for of in sdf_paths]
    obj_cls = []
    new_obj_files = []
    latent_obj_list = []
    dsname_list = []
    for i, of in enumerate(tqdm(sdf_paths)):
        try:
            jkey, subkey = jax.random.split(jkey)
            occ_ds, latent_obj_ = create_obj_dataset(of, subkey, rel_sdf_paths[i])
            obj_cls.append(occ_ds)
            latent_obj_list.append(latent_obj_)
            new_obj_files.append(of)
            dsname_list.append(of.split('/')[-3])
        except Exception as e:
            print(f"AssertionError: {e}")

    # check whether no mismatch between obj_files and new_obj_files
    for of, nof in zip(sdf_paths, new_obj_files):
        assert of == nof

    sdf_paths = new_obj_files
    obj_no = len(sdf_paths)
    print(f'entire number od objects: {obj_no}')


    # load scale and transformation
    mesh_translation_to_origin_list = []
    mesh_scale_to_origin_list = []
    mesh_obj_path = []
    for i, of in enumerate(tqdm(sdf_paths)):
        with open(of, 'rb') as f:
            occ_data_ = pickle.load(f)
            mesh_translation_to_origin_list.append(occ_data_['translation'])
            mesh_scale_to_origin_list.append(occ_data_['scale'])
            mesh_obj_path.append(occ_data_['path'])

    rel_mesh_obj_path = [of.split('object_set/')[-1] for of in mesh_obj_path]

    occ_dataset:OccDataset = jax.tree_util.tree_map(lambda *x: jnp.stack(x, 0), *obj_cls)  # upload dataset to the GPU memory
    latent_obj_list:loutil.LatentObjects = jax.tree_util.tree_map(lambda *x: jnp.stack(x, 0), *latent_obj_list)

    # load collision dataset from col_data/*.pkl
    print('loading collision dataset...')
    col_data_base_dir = os.path.join('col_data_patch5')
    if args.col_data_base_dir is not None:
        col_data_base_dir = os.path.join(args.col_data_base_dir, col_data_base_dir)
    if not os.path.exists(col_data_base_dir):
        print('downloading collision dataset...')
        os.makedirs(col_data_base_dir, exist_ok=True)
        command = f"sshpass -p {SSH_PASSWORD} scp -r -P {SSH_PORT} {SSH_IP}:research/PointObjRep/col_data_patch4 {os.path.dirname(col_data_base_dir)}"
        process = subprocess.run(shlex.split(command), check=True)
        if process.returncode != 0:
            print(f"Error occurred, return code: {process.returncode}")
    
    col_data_dir_list = glob.glob(os.path.join(col_data_base_dir, '*.pkl'))
    
    if args.train_patch_col:
        canonical_FPSs = np.array(latent_obj_list.init_pos_zero().fps_tf)
        col_patch_label_jit = jax.jit(partial(col_patch_label, canonical_FPSs=jnp.array(canonical_FPSs)))
    col_data_list = []
    for col_data_dir in tqdm(col_data_dir_list[:COL_DATASET_LOAD_NO]):
        with open(col_data_dir, "rb") as f:
            obj_AB_idx, obj_AB_scale, obj_AB_pos, obj_AB_quat, col_res_list = (
                pickle.load(f)
            )
        if len(obj_AB_idx) == 0:
            continue
        
        if args.train_patch_col:
            fps_col_data_dir = os.path.join(os.path.dirname(col_data_dir), f'fps_col_data_{args.tree_width}', os.path.basename(col_data_dir).split('.')[0] + '.npy')
            os.makedirs(os.path.dirname(fps_col_data_dir), exist_ok=True)
            if os.path.exists(fps_col_data_dir):
                fps_col_labels = np.load(fps_col_data_dir)
            else:
                col_patch_gt = np.stack([cr[0] for cr in col_res_list])
                closest_pnt_A = np.stack([cr[3] for cr in col_res_list])
                closest_pnt_B = np.stack([cr[4] for cr in col_res_list])
                closest_pnt_AB = np.stack([closest_pnt_A, closest_pnt_B], axis=-3)
                fps_col_labels = col_patch_label_jit(jnp.stack(obj_AB_idx), jnp.stack(obj_AB_scale), jnp.stack(obj_AB_pos), jnp.stack(obj_AB_quat), col_patch_gt, closest_pnt_AB)
                fps_col_labels = np.array(fps_col_labels)
                np.save(fps_col_data_dir, fps_col_labels)
        else:
            fps_col_labels = None

        col_data_list.append(
            ColDataset(
                obj_idx=np.stack(obj_AB_idx),
                obj_scale=np.stack(obj_AB_scale),
                obj_pos=np.stack(obj_AB_pos),
                obj_quat=np.stack(obj_AB_quat),
                col_gt=np.stack([cr[0][0] for cr in col_res_list]),
                # distance=np.stack([cal_min_distance(cr) for cr in col_res_list]),
                # distance=cal_min_distance(jax.tree_util.tree_map(lambda *x: np.stack(x, axis=0), *col_res_list)),
                distance=None,
                min_direction=np.stack([cr[2][0] for cr in col_res_list]),
                fps_col_labels=fps_col_labels,
            )
        )

        # visualize dataset
        # col_data_list[-1].visualize_fps_col_labels(latent_obj_list, rot_configs)

    col_data_list = jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis=0), *col_data_list)

    # save entire col dataset

    # shuffle col data
    entire_index = np.random.default_rng(seed=42).permutation(len(col_data_list))
    col_data_list_train = jax.tree_util.tree_map(lambda x: x[entire_index[:int(x.shape[0]*0.9)]], col_data_list)
    col_data_list_test = jax.tree_util.tree_map(lambda x: jnp.array(x[entire_index[int(x.shape[0]*0.9):]]), col_data_list)
    print(f'collision dataset is loaded / train: {len(col_data_list_train)} / test: {len(col_data_list_test)}')

    # load collision decoders
    if args.use_encoder:
        enc = mutil.ShapeEncoder(rot_configs, output_feat_dim=TREE_FEAT_DIM, nfps=TREE_WIDTH)
        jkey, subkey = jax.random.split(jkey)
        test_oriCORNs, enc_params = enc.init_with_output(subkey, occ_dataset.surface_points[:2,:args.enc_train_npnts], jkey=jkey)
    else:
        enc_params = None
        test_oriCORNs = latent_obj_list[:2].init_pos_zero()
    col_dec = ColDecoderSceneV2(args.dropout_rate, rot_configs, version=args.col_dec_version, multihead_no=args.multihead_no,
                                depth=args.col_dec_depth, pos_emb_size_divider=args.pos_emb_size_divider,
                                feat_size_divider=args.feat_size_divider,
                                normalize_eps=args.normalize_eps)
    jkey, subkey = jax.random.split(jkey)
    col_dec_params = col_dec.init(subkey, test_oriCORNs[0], test_oriCORNs[1], jkey=jkey)

    params = (col_dec_params, enc_params, latent_obj_list)

    def loss_func(params, jkey, col_data_minibatch:ColDataset, occ_data_minibatch:OccDataset, mulcol_data_minibatch:ColDataset, rot_aug=False, train=True):

        if args.use_encoder:
            sampled_obj_params = enc.apply(params[1], occ_data_minibatch.surface_points, jkey=jkey, train=train)
        else:
            # apply stop gradient to canonical object
            canonical_obj_params = params[-1].init_pos_zero()
            canonical_obj_params = replace(canonical_obj_params, pos=jax.lax.stop_gradient(canonical_obj_params.pos))
            canonical_obj_params = replace(canonical_obj_params, rel_fps=jax.lax.stop_gradient(canonical_obj_params.rel_fps))
            
            # if args.sdf_loss_weight!=0:
            sampled_obj_params = canonical_obj_params[col_data_minibatch.obj_idx]
            # selected_latent_obj_pairs = col_data_minibatch.make_latent_obj(canonical_obj_params, rot_configs) # collision pair for collision dataset

        selected_latent_obj_pairs = col_data_minibatch.make_latent_obj(None, rot_configs, latent_obj=sampled_obj_params) # collision pair for collision dataset
        sampled_obj_params = sampled_obj_params.reshape_outer_shape((-1,))
        occ_data_minibatch = jax.tree_util.tree_map(lambda x: einops.rearrange(x, 'n m ... -> (n m) ...'), occ_data_minibatch)
        sd_gt = occ_data_minibatch.signed_distance
        qpnts = occ_data_minibatch.query_points
        rays = occ_data_minibatch.rays
        ray_t_hit = occ_data_minibatch.t_hit
        ray_hitting_pnts = rays[...,:3] + ray_t_hit * tutil.normalize(rays[...,3:])
        ray_hitting_fps_idx = jnp.argmin(jnp.linalg.norm(sampled_obj_params.fps_tf[...,None,:,:] - ray_hitting_pnts[...,None,:], axis=-1), axis=-1)

        # rotational augmentation
        nobj = qpnts.shape[0]
        jkey, subkey = jax.random.split(jkey)
        random_scale_queries = jax.random.uniform(subkey, minval=0.03, maxval=1.2, shape=(nobj,1,1))
        qpnts = qpnts * random_scale_queries
        sampled_obj_params = sampled_obj_params.apply_scale(random_scale_queries)
        jkey, subkey = jax.random.split(jkey)
        random_quat = tutil.qrand((nobj,), subkey)
        qpnts = tutil.qaction(random_quat[...,None,:], qpnts)
        sampled_obj_params = sampled_obj_params.rotate_z(random_quat, rot_configs)
        rays = jnp.c_[random_scale_queries*tutil.qaction(random_quat[...,None,:], rays[...,:3]), tutil.qaction(random_quat[...,None,:], rays[...,3:])]
        ray_t_hit = ray_t_hit*random_scale_queries
    
        # selected_latent_obj_pairs = col_data_minibatch.make_latent_obj(canonical_obj_params, rot_configs)
        # selected_latent_obj_pairs = col_data_minibatch.make_latent_obj(None, rot_configs)
        if args.add_path_col_loss:
            oriCORNs_wo_rot_aug = selected_latent_obj_pairs
            path_obj_A, path_obj_B = path_col_datagen(col_data_minibatch, selected_latent_obj_pairs, jkey)
        if rot_aug:
            npair = selected_latent_obj_pairs.shape[0]
            jkey, subkey = jax.random.split(jkey)
            random_scale = jax.random.uniform(subkey, minval=0.03, maxval=1.2, shape=(npair,1,1))
            jkey, subkey = jax.random.split(jkey)
            random_quat = tutil.qrand((npair,1), subkey)
            jkey, subkey = jax.random.split(jkey)
            random_translation = jax.random.uniform(subkey, minval=-1.2, maxval=1.2, shape=(npair,1,3))
            selected_latent_obj_pairs = selected_latent_obj_pairs.apply_scale(random_scale, center=jnp.mean(selected_latent_obj_pairs.pos, axis=1, keepdims=True))
            selected_latent_obj_pairs = selected_latent_obj_pairs.apply_pq_z(random_translation, random_quat, rot_configs)

        # fps_col_labels = col_patch_label_jit(selected_latent_obj_pairs[:,0], selected_latent_obj_pairs[:,1], col_data_minibatch, rot_configs)


        def fps_grad_func(fps_pnts, objA, objB, **kwargs):
            objB_tmp = objB
            objB_tmp = objB_tmp.replace(rel_fps=objB_tmp.rel_fps + fps_pnts)
            path_col_logits = col_dec.apply(params[0], objA, objB_tmp, **kwargs)
            return jnp.sum(path_col_logits[2]), path_col_logits

        if args.global_rep_baseline:
            reduce_k_devider_pool = [1]
        else:
            reduce_k_devider_pool = [1, 2]
        col_loss = 0
        path_col_loss = 0
        occ_loss = 0
        ray_loss = 0
        col_acc_dict = {}
        path_col_acc_dict = {}
        for idx, reduce_k_devider in enumerate(reduce_k_devider_pool):
            jkey, subkey = jax.random.split(jkey, 2)
            if args.grad_loss_weight!=0:
                fps_grad, col_logits = jax.grad(fps_grad_func, has_aux=True)(jnp.zeros_like(selected_latent_obj_pairs[:,1].rel_fps), 
                                                            selected_latent_obj_pairs[:,0], selected_latent_obj_pairs[:,1], 
                                                                reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
            else:
                col_logits = col_dec.apply(params[0], selected_latent_obj_pairs[:,0], 
                                        selected_latent_obj_pairs[:,1], 
                                        reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                fps_grad = None
            col_loss_, col_acc_dict_ = cal_col_loss(col_data_minibatch, col_logits, 
                                                    calculate_patch_col=args.train_patch_col==1, fps_grad=fps_grad)
            col_acc_dict = {**col_acc_dict, **{k+f"_{idx}": v for k, v in col_acc_dict_.items()}}
            col_loss += col_loss_
            if args.add_path_col_loss:
                line_segment_B = path_obj_B[:,-1].pos - path_obj_B[:,0].pos

                jkey, subkey = jax.random.split(jkey)

                if args.grad_loss_weight!=0:
                    path_fps_grad, path_col_logits = jax.grad(fps_grad_func, has_aux=True)(jnp.zeros_like(path_obj_B[:,0].rel_fps), oriCORNs_wo_rot_aug[:,0], path_obj_B[:,0],
                                                                                        line_segment_B=line_segment_B, reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                else:
                    path_col_logits = col_dec.apply(params[0], oriCORNs_wo_rot_aug[:,0], 
                                                    path_obj_B[:,0], line_segment_B=line_segment_B, 
                                                    reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                    path_fps_grad = None

                path_col_loss_, path_col_acc_dict_ = cal_col_loss(col_data_minibatch, path_col_logits, 
                                                                calculate_patch_col=args.train_patch_col==1, 
                                                                apply_patch_col_mask=1, fps_grad=path_fps_grad)
                path_col_acc_dict ={**path_col_acc_dict, **{"path_"+k+f"_{idx}": v for k, v in path_col_acc_dict_.items()}}
                path_col_loss += path_col_loss_

            if args.sdf_loss_weight!=0:
                sampled_obj_params:loutil.LatentObjects
                # point_oriCORN = params[1]
                point_oriCORN = loutil.LatentObjects(pos=jnp.zeros((3,), dtype=jnp.float32), 
                                                     rel_fps=jnp.zeros((1, 3), dtype=jnp.float32), 
                                                     z=jnp.zeros((1, *sampled_obj_params.latent_shape[-2:]), dtype=jnp.float32))
                # point_oriCORN = point_oriCORN.replace(z=jnp.zeros_like(point_oriCORN.z))
                point_oriCORN:loutil.LatentObjects = point_oriCORN[None,None].translate(qpnts)
                point_oriCORN = point_oriCORN.broadcast_outershape(qpnts.shape[:-1])
                point_oriCORN = jax.lax.stop_gradient(point_oriCORN) # just zero vector
                sampled_obj_params_rp:loutil.LatentObjects = sampled_obj_params.extend_and_repeat_outer_shape(qpnts.shape[-2], -1)

                jkey, subkey = jax.random.split(jkey)

                if args.grad_loss_weight!=0:
                    occ_fps_grad, (occ_logit, occ_patch_logit, occ_pred2) = jax.grad(fps_grad_func, has_aux=True)(jnp.zeros_like(point_oriCORN.rel_fps), 
                                                                                        sampled_obj_params_rp, point_oriCORN,
                                                                                    reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                    occ_grad_loss = cal_grad_loss(occ_fps_grad)
                else:
                    occ_logit, occ_patch_logit, occ_pred2 = col_dec.apply(params[0], sampled_obj_params_rp, point_oriCORN, 
                                                            reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                    occ_grad_loss = 0
                sd_loss_ = binary_focal_loss(occ_logit, sd_gt<0, alpha=0.5)
                sd_loss_ += binary_focal_loss(occ_patch_logit, (sd_gt<0).repeat(occ_patch_logit.shape[-2], -1), alpha=0.5, apply_col_loss_mask=2)
                sd_loss_ += occ_grad_loss*args.grad_loss_weight
                sd_acc = jnp.mean((occ_logit > 0).astype(jnp.int32) == (sd_gt < 0).astype(jnp.int32))
                col_acc_dict = {**col_acc_dict, f'sd_acc_{idx}': sd_acc, f'occ_grad_loss_{idx}': occ_grad_loss}
            else:
                sd_loss_ = 0
                sd_acc = 0
            occ_loss += sd_loss_

            if args.ray_loss_weight==0:
                ray_seg_loss_ = 0
                seg_acc = 0
            else:
                # ray loss
                ray_start_pnt, ray_dir = rays[...,:3], rays[...,3:]
                jkey, subkey = jax.random.split(jkey)
                # ray_oriCORN = params[1]
                ray_oriCORN = loutil.LatentObjects(pos=jnp.zeros((3,), dtype=jnp.float32), 
                                                     rel_fps=jnp.zeros((1, 3), dtype=jnp.float32), 
                                                     z=jnp.zeros((1, *sampled_obj_params.latent_shape[-2:]), dtype=jnp.float32))
                # ray_oriCORN = ray_oriCORN.replace(z=jnp.zeros_like(ray_oriCORN.z))
                ray_oriCORN:loutil.LatentObjects = ray_oriCORN[None,None].translate(ray_start_pnt)
                ray_oriCORN = ray_oriCORN.broadcast_outershape(ray_start_pnt.shape[:-1])
                ray_oriCORN = jax.lax.stop_gradient(ray_oriCORN) # just zero vector

                if args.grad_loss_weight!=0:
                    seg_fps_grad, (seg_logit, seg_patch_logit, _) = jax.grad(fps_grad_func, has_aux=True)(jnp.zeros_like(ray_oriCORN.rel_fps), 
                                                                                        sampled_obj_params_rp, ray_oriCORN,
                                                                                    line_segment_B=10*ray_dir, reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                    seg_grad_loss = cal_grad_loss(seg_fps_grad)
                else:
                    seg_logit, seg_patch_logit, _ = col_dec.apply(params[0], sampled_obj_params_rp, ray_oriCORN, line_segment_B=10*ray_dir,
                                                    reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                    seg_grad_loss = 0
                ray_seg_gt = (ray_t_hit >= 0).astype(jnp.float32)
                ray_seg_loss_ = binary_focal_loss(seg_logit, ray_seg_gt, alpha=0.5)
                seg_patch_label = jax.nn.one_hot(ray_hitting_fps_idx, num_classes=sampled_obj_params_rp.nfps, axis=-1)
                seg_patch_label = jnp.where(ray_seg_gt, seg_patch_label, 0)
                ray_seg_loss_ += binary_focal_loss(seg_patch_logit, seg_patch_label, alpha=0.8, apply_col_loss_mask=1)
                ray_seg_loss_ += seg_grad_loss*args.grad_loss_weight
                seg_acc = jnp.mean((ray_seg_gt).astype(jnp.int32) == (seg_logit > 0.).astype(jnp.int32))
                col_acc_dict = {**col_acc_dict, f'ray_seg_acc_{idx}': seg_acc, f'seg_grad_loss_{idx}': seg_grad_loss}


                nray_ptb = ray_start_pnt.shape[-2]//3
                hit_idx = jax.vmap(partial(jnp.where, size=nray_ptb))(ray_t_hit.squeeze(-1) >= 0)[0]
                sampled_obj_params_rp_hit = sampled_obj_params_rp.take_along_outer_axis(hit_idx, axis=1)
                ray_t_hit_hit = jnp.take_along_axis(ray_t_hit, hit_idx[...,None], axis=-2)
                ray_dir_hit = jnp.take_along_axis(ray_dir, hit_idx[...,None], axis=-2)
                ray_dir_hit = tutil.normalize(ray_dir_hit)
                ray_oriCORN_hit = ray_oriCORN.take_along_outer_axis(hit_idx, axis=1)
                ray_hitting_fps_idx_hit = jnp.take_along_axis(ray_hitting_fps_idx, hit_idx, axis=-1)

                if sampled_obj_params_rp_hit.nfps > 4:
                    obj_len = sampled_obj_params_rp_hit.mean_fps_dist[...,None]
                else:
                    obj_len = random_scale_queries/4
                    for _ in range(obj_len.ndim - obj_len.ndim):
                        obj_len = obj_len[...,None]
                ray_seg_mask = (ray_t_hit_hit >= 0).astype(jnp.float32)
                jkey, subkey1, subkey2, subkey3 = jax.random.split(jkey, 4)
                ray_t_hit_ptb_pos = jax.random.uniform(subkey1, minval=ray_t_hit_hit, maxval=ray_t_hit_hit+2*obj_len, shape=ray_t_hit_hit.shape) # collision
                ray_t_hit_ptb_neg = jax.random.uniform(subkey2, minval=ray_t_hit_hit-1*obj_len, maxval=ray_t_hit_hit, shape=ray_t_hit_hit.shape) # no collision
                ray_hit_ptb_mask = jax.random.uniform(subkey3, shape=ray_t_hit_hit.shape)>1/2
                ray_t_hit_ptb = jnp.where(ray_hit_ptb_mask, ray_t_hit_ptb_pos, ray_t_hit_ptb_neg)
                ray_start_t = jax.random.uniform(subkey2, minval=ray_t_hit_hit-obj_len*20.0, maxval=ray_t_hit_hit-obj_len*0.5, shape=ray_t_hit_hit.shape)
                ray_t_hit_ptb = ray_t_hit_ptb - ray_start_t
                ray_t_hit_ptb = jnp.where(ray_seg_mask, ray_t_hit_ptb, 10.0)
                ray_seg_gt_hit = jnp.where(ray_seg_mask, ray_hit_ptb_mask, False)
                ray_start_t = jnp.where(ray_seg_mask, ray_start_t, 0)

                ray_oriCORN_hit = ray_oriCORN_hit.replace(pos=ray_oriCORN_hit.pos + ray_start_t*ray_dir_hit)

                if args.debug:
                    # visualize points
                    ray_surface_pnts = ray_oriCORN_hit.pos + (ray_t_hit_hit-ray_start_t)*ray_dir_hit
                        
                # sometimes, flip line segment direction!
                jkey, subkey = jax.random.split(jkey, 2)
                flip_mask = jax.random.uniform(subkey, shape=ray_t_hit_hit.shape)>1/2
                ray_oriCORN_hit = ray_oriCORN_hit.replace(pos=jnp.where(flip_mask, ray_oriCORN_hit.pos + ray_t_hit_ptb*ray_dir_hit, ray_oriCORN_hit.pos))
                ray_dir_hit = jnp.where(flip_mask, -ray_dir_hit, ray_dir_hit)
                ray_oriCORN_hit = jax.lax.stop_gradient(ray_oriCORN_hit)
                ray_dir_hit = jax.lax.stop_gradient(ray_dir_hit)

                if args.debug:
                    # visualize points
                    # ray_surface_pnts = ray_oriCORN_hit.pos + ray_t_hit_hit*ray_dir_hit
                    import open3d as o3d
                    for i in range(sampled_obj_params_rp_hit.shape[0]):
                        pcd_obj = o3d.geometry.PointCloud()
                        pcd_obj.points = o3d.utility.Vector3dVector(sampled_obj_params_rp_hit[i][0].fps_tf)
                        pcd_obj.paint_uniform_color([1, 0.706, 0])
                        pcd_obj_hit = o3d.geometry.PointCloud()
                        pcd_obj_hit.points = o3d.utility.Vector3dVector(ray_surface_pnts[i])
                        pcd_obj_hit.paint_uniform_color([0, 0.651, 0.929])
                        # draw rays
                        vis_hit_mask = ray_seg_gt_hit[i].squeeze(-1)
                        ray_vis_pnts = jnp.concatenate([ray_oriCORN_hit.pos[i][vis_hit_mask], ray_oriCORN_hit.pos[i][vis_hit_mask]+(ray_t_hit_ptb*ray_dir_hit)[i][vis_hit_mask]], axis=-2)
                        ray_vis_idx = jnp.stack([jnp.arange(ray_vis_pnts.shape[-2]//2), jnp.arange(ray_vis_pnts.shape[-2]//2)+ray_vis_pnts.shape[-2]//2], axis=-1)
                        ray_vis_hit = o3d.geometry.LineSet()
                        ray_vis_hit.points = o3d.utility.Vector3dVector(ray_vis_pnts)
                        ray_vis_hit.lines = o3d.utility.Vector2iVector(ray_vis_idx)
                        ray_vis_hit.paint_uniform_color([0.2, 0.251, 0.429])

                        vis_nhit_mask = jnp.logical_not(ray_seg_gt_hit[i].squeeze(-1))
                        ray_vis_pnts = jnp.concatenate([ray_oriCORN_hit.pos[i][vis_nhit_mask], ray_oriCORN_hit.pos[i][vis_nhit_mask]+(ray_t_hit_ptb*ray_dir_hit)[i][vis_nhit_mask]], axis=-2)
                        ray_vis_idx = jnp.stack([jnp.arange(ray_vis_pnts.shape[-2]//2), jnp.arange(ray_vis_pnts.shape[-2]//2)+ray_vis_pnts.shape[-2]//2], axis=-1)
                        ray_vis_nhit = o3d.geometry.LineSet()
                        ray_vis_nhit.points = o3d.utility.Vector3dVector(ray_vis_pnts)
                        ray_vis_nhit.lines = o3d.utility.Vector2iVector(ray_vis_idx)
                        ray_vis_nhit.paint_uniform_color([0.8, 0.951, 0.129])

                        o3d.visualization.draw_geometries([pcd_obj, pcd_obj_hit, ray_vis_hit, ray_vis_nhit])

                if args.grad_loss_weight!=0:
                    seg_fps_grad, (seg_logit, seg_patch_logit, _) = jax.grad(fps_grad_func, has_aux=True)(jnp.zeros_like(ray_oriCORN_hit.rel_fps), 
                                                                                        sampled_obj_params_rp_hit, ray_oriCORN_hit,
                                                                                    line_segment_B=ray_t_hit_ptb*ray_dir_hit, reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                    seg_grad_loss = cal_grad_loss(seg_fps_grad)
                else:
                    seg_logit, seg_patch_logit, _ = col_dec.apply(params[0], sampled_obj_params_rp_hit, ray_oriCORN_hit, line_segment_B=ray_t_hit_ptb*ray_dir_hit,
                                                    reduce_k=args.col_reduce_k//reduce_k_devider, jkey=subkey, train=train)
                    seg_grad_loss = 0
                ray_seg_loss_ += binary_focal_loss(seg_logit, ray_seg_gt_hit, alpha=0.5)
                seg_patch_label = jax.nn.one_hot(ray_hitting_fps_idx_hit, num_classes=sampled_obj_params_rp_hit.nfps, axis=-1)
                seg_patch_label = jnp.where(ray_seg_gt_hit, seg_patch_label, 0)
                ray_seg_loss_ += binary_focal_loss(seg_patch_logit, seg_patch_label, alpha=0.8, apply_col_loss_mask=1)
                ray_seg_loss_ += seg_grad_loss*args.grad_loss_weight
                seg_acc = jnp.mean((ray_seg_gt_hit).astype(jnp.int32) == (seg_logit > 0.).astype(jnp.int32))
                col_acc_dict = {**col_acc_dict, f'ray_hit_seg_acc_{idx}': seg_acc, f'seg_hit_grad_loss_{idx}': seg_grad_loss}

            ray_loss += ray_seg_loss_
            

        # add regularization loss - mean / variance of z - covariance regularizer
        z_flat_vec = sampled_obj_params.z_flat.reshape(-1, sampled_obj_params.z_flat.shape[-1])
        z_flat_norm = jnp.mean(jnp.linalg.norm(z_flat_vec, axis=-1))
        reg_loss = 0
        if args.reg_loss_weight!=0:
            N, d = z_flat_vec.shape
            z_mean = jnp.mean(z_flat_vec, axis=0)
             # Center the vectors by subtracting the mean
            # centered_vectors = z_flat_vec - z_mean  # Shape: (N, d)
            centered_vectors = z_flat_vec # Shape: (N, d)
            # Compute the covariance matrix (N-1 normalization for unbiased estimate)
            cov_matrix = jnp.einsum('ij,ik->jk', centered_vectors, centered_vectors) / (N - 1)  # Shape: (d, d)
            # Compute the Frobenius norm of the difference between covariance and identity matrix
            cov_reg_loss = jnp.sum((cov_matrix - jnp.eye(d)) ** 2)
            mean_reg_loss = jnp.sum(z_mean**2)
            reg_loss = cov_reg_loss + mean_reg_loss

        loss_value = args.sdf_loss_weight*occ_loss + args.col_loss_weight*col_loss + args.col_loss_weight*path_col_loss + \
                len(reduce_k_devider_pool)*args.reg_loss_weight*reg_loss + args.ray_loss_weight*ray_loss

        aux_info = {'sd_loss': occ_loss, 'col_loss': col_loss, 'path_col_loss': path_col_loss, 'ray_loss': ray_loss, 'reg_loss': reg_loss,
                    'z_flat_norm': z_flat_norm,
                    **col_acc_dict, **path_col_acc_dict}
        return loss_value, aux_info

    # Gradient of loss_func w.r.t. params; calling it returns (grads, aux_info).
    loss_grad = jax.grad(loss_func, has_aux=True)

    # AdamW, learning rate 3e-4, optax's default weight decay.
    optimizer = optax.adamw(3e-4)
    opt_state = optimizer.init(params)

    # Optionally resume from CKPT_DIR/save_dict.pkl.
    start_itr = 0
    steps_cumulative = 0
    if CKPT_DIR is not None:
        # NOTE(review): pickle.load can execute arbitrary code -- only resume from trusted checkpoints.
        with open(os.path.join(CKPT_DIR, 'save_dict.pkl'), 'rb') as f:
            save_dict = pickle.load(f)
        params_ckpt = save_dict['params']
        opt_state_ckpt = save_dict['opt_state']
        if args.update_sdf_dir is not None:
            # The object set was updated: write checkpointed latent object i into slot
            # old_to_new_idx[i] of the current latent-object table. The fresh optimizer
            # state is kept, and params[:-1] are NOT taken from the checkpoint.
            # NOTE(review): confirm the decoder weights are meant to stay freshly initialised here.
            updated_latent_obj_list = params[-1]
            for i, j in enumerate(old_to_new_idx):
                # tree_map runs immediately, so the lambda's late binding of i/j is safe.
                updated_latent_obj_list = jax.tree_util.tree_map(lambda x, y: x.at[j].set(y[i]), updated_latent_obj_list, params_ckpt[-1])
            params = (*params[:-1], updated_latent_obj_list)
            print('updated latent objects')
        else:
            # Same object set: the checkpoint must list the same files in the same order
            # (file name and the path component two levels above it must match).
            new_obj_files = save_dict['obj_filename_list'] if 'obj_filename_list' in save_dict else save_dict['sdf_paths']
            for of, nof in zip(sdf_paths, new_obj_files):
                assert of.split('/')[-1] == nof.split('/')[-1]
                assert of.split('/')[-3] == nof.split('/')[-3]
            if args.use_encoder and params_ckpt[-1] is not None:
                # Only the decoder (params[0]) is restored; the optimizer state stays fresh.
                params = (params_ckpt[0], *params[1:])
            else:
                params = params_ckpt
                opt_state = opt_state_ckpt
        rot_configs = save_dict['rot_configs']
        start_itr = save_dict['itr']
        steps_cumulative = save_dict.get('steps_cumulative', steps_cumulative)
        print(f'loaded from {CKPT_DIR} at iteration {start_itr}')
        # Release the raw checkpoint copies. Both names are still bound on every path above;
        # neither may be deleted inside a branch, or this line raises UnboundLocalError.
        del params_ckpt, opt_state_ckpt


    @jax.jit
    def loss_update(params, opt_state, jkey, col_data_minibatch, occ_data_minibatch, mulcol_data_minibatch):
        """One jitted training step: loss gradient, NaN scrubbing, optimizer update.

        Rotation augmentation is enabled when `args.train_rot_aug == 1` (read at trace time).

        Returns:
            (updated params, updated optimizer state, next PRNG key, aux_info dict from loss_func).
        """
        next_key, step_key = jax.random.split(jkey)
        use_rot_aug = args.train_rot_aug == 1
        grads, aux_info = loss_grad(params, step_key, col_data_minibatch, occ_data_minibatch,
                                    mulcol_data_minibatch, rot_aug=use_rot_aug)
        # NaN gradient entries are replaced with 0 before the optimizer sees them.
        # (Global-norm gradient clipping is currently not applied.)
        grads = jax.tree_util.tree_map(lambda g: jnp.where(jnp.isnan(g), 0, g), grads)
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state, next_key, aux_info

    def sample_dataset(obj_batch_size, query_size, jkey, ds_type='train'):
        """Sample a random minibatch from the train or test collision data.

        Args:
            obj_batch_size: number of records drawn uniformly, with replacement, from the
                collision dataset.
            query_size: number of occupancy query points drawn, with replacement, from indices
                [0, DATASET_SIZE) for each entry of the size-2 axis (presumably the two
                objects of a pair -- confirm against ColDataset.obj_idx).
            jkey: JAX PRNG key. It is only split here, never consumed directly.
            ds_type: 'train' draws from col_data_list_train; any other value from col_data_list_test.

        Returns:
            (selected_col_data, occ_dataset_selected, selected_mulcol_data, jkey), where
            selected_mulcol_data is always None and jkey is an unused key for the caller.
        """
        col_ds_ = col_data_list_train if ds_type == 'train' else col_data_list_test
        # Split once up front so that no key is both consumed by a sampler and split again
        # (JAX key reuse gives correlated draws).
        jkey, pair_key, query_key = jax.random.split(jkey, 3)
        random_col_data_idx = jax.random.randint(pair_key, shape=(obj_batch_size,), minval=0, maxval=len(col_ds_))
        selected_col_data: ColDataset = col_ds_[random_col_data_idx]
        selected_col_data = jax.tree_util.tree_map(lambda x: jnp.array(x), selected_col_data)
        occ_dataset_selected = occ_dataset[selected_col_data.obj_idx]

        # Keep the full surface-point set aside; the take_along_axis below subsamples every field.
        surface_points = occ_dataset_selected.surface_points if args.use_encoder else None

        random_qp_idx = jax.random.randint(query_key, shape=(obj_batch_size, 2, query_size), minval=0, maxval=DATASET_SIZE)
        occ_dataset_selected = jax.tree_util.tree_map(
            lambda x: jnp.take_along_axis(x, random_qp_idx[..., None], axis=-2), occ_dataset_selected)

        if args.use_encoder:
            jkey, surface_key = jax.random.split(jkey)
            # Sample encoder input points from ALL stored surface points (point axis is -2),
            # not just the first `query_size` of them.
            n_surface = surface_points.shape[-2]
            random_sp_idx = jax.random.randint(surface_key, shape=(obj_batch_size, 2, args.enc_train_npnts), minval=0, maxval=n_surface)
            surface_points = jnp.take_along_axis(surface_points, random_sp_idx[..., None], axis=-2)
            occ_dataset_selected = replace(occ_dataset_selected, surface_points=surface_points)

        selected_mulcol_data = None

        return selected_col_data, occ_dataset_selected, selected_mulcol_data, jkey

    def train_func(params, opt_state, jkey, inner_itr_no=200):
        """Run `inner_itr_no` optimisation steps on one pre-sampled chunk of training data.

        sample_dataset is called once for inner_itr_no*OBJ_BATCH_SIZE records. Consecutive
        OBJ_BATCH_SIZE slices are then fed to loss_update, one slice per step.

        Returns:
            (params, opt_state, jkey, aux_loss), where aux_loss is the element-wise mean of
            the aux dicts returned by the inner steps.

        Raises:
            ValueError: if inner_itr_no < 1 (there would be no steps to average over).
        """
        if inner_itr_no < 1:
            raise ValueError(f'inner_itr_no must be >= 1, got {inner_itr_no}')
        selected_col_data, occ_dataset_selected, selected_mulcol_data, jkey = sample_dataset(
            inner_itr_no*OBJ_BATCH_SIZE, QUERY_BATCH_SIZE, jkey, ds_type='train')

        for i in range(inner_itr_no):
            batch = slice(i*OBJ_BATCH_SIZE, (i+1)*OBJ_BATCH_SIZE)
            jkey, subkey = jax.random.split(jkey)
            params, opt_state, jkey, aux_loss_ = loss_update(
                params, opt_state, subkey,
                selected_col_data[batch],
                occ_dataset_selected[batch],
                selected_mulcol_data[batch] if selected_mulcol_data is not None else None)
            # Running sum of the per-step aux dicts; averaged after the loop.
            aux_loss = aux_loss_ if i == 0 else jax.tree_util.tree_map(lambda x, y: x+y, aux_loss, aux_loss_)
        aux_loss = jax.tree_util.tree_map(lambda x: x/inner_itr_no, aux_loss)
        return params, opt_state, jkey, aux_loss

    def eval_func(params, selected_col_data, occ_dataset_selected, selected_mulcol_data, jkey):
        """Evaluate loss_func on a held-out minibatch and return only its aux-metric dict.

        Rotation augmentation is always on, and loss_func is called with train=False.
        """
        loss_and_aux = loss_func(params, jkey, selected_col_data, occ_dataset_selected,
                                 selected_mulcol_data, rot_aug=True, train=False)
        return loss_and_aux[1]
    
    # Choose how eval_func is executed.
    # Debug mode: keep it un-jitted (presumably so the Python-side debug code in loss_func,
    # e.g. the open3d visualisation under `args.debug`, can run on concrete arrays) and
    # immediately run one evaluation on a small test batch (128 queries) as a smoke test.
    # Otherwise: compile eval_func with jax.jit.
    if args.debug:
        eval_func_jit = eval_func
        jkey, subkey = jax.random.split(jkey)
        selected_col_data, occ_dataset_selected, selected_mulcol_data, jkey = sample_dataset(OBJ_BATCH_SIZE, 128, subkey, ds_type='test')
        jkey, subkey = jax.random.split(jkey)
        eval_out = eval_func_jit(params, selected_col_data, occ_dataset_selected, selected_mulcol_data, subkey)
    
    else:
        eval_func_jit = jax.jit(eval_func)

    # Metric sinks: a TensorBoard SummaryWriter in LOG_DIR and a wandb run whose config is the parsed CLI args.
    writer = SummaryWriter(log_dir=LOG_DIR)
    wandb.init(project='rep_pointfeat_training', config=args)

    # save logging files
    # Python logging goes to LOG_DIR/train.log; the log dir and the full arg dict are recorded once here.
    logging.basicConfig(filename=os.path.join(LOG_DIR, 'train.log'), level=logging.INFO)
    logging.info(f'logging to {LOG_DIR}')
    logging.info(f'config: {args.__dict__}')
    
    # copy files
    # Snapshot this script and util/model_util.py (resolved relative to this script's directory)
    # into LOG_DIR, so the code that produced the run is stored next to its logs.
    shutil.copyfile(__file__, os.path.join(LOG_DIR, os.path.basename(__file__)))
    BASEDIR = os.path.dirname(__file__)
    shutil.copy(os.path.join(BASEDIR, 'util/model_util.py'), os.path.join(LOG_DIR, 'model_util.py'))

    # Ensure LOG_DIR/mesh exists.
    mesh_dir = os.path.join(LOG_DIR, 'mesh')
    os.makedirs(mesh_dir, exist_ok=True)

    # Persist per-object metadata, one object per line, in list order:
    #   translation_to_origin.txt - the entries of each translation, space-separated
    #   scale_to_origin.txt       - one scale value per line
    #   sdf_dirs.txt / obj_dirs.txt - the SDF paths and the relative mesh paths
    with open(os.path.join(LOG_DIR, 'translation_to_origin.txt'), 'w') as f:
        for mto in mesh_translation_to_origin_list:
            f.write(' '.join([str(m) for m in mto])+'\n')
    with open(os.path.join(LOG_DIR, 'scale_to_origin.txt'), 'w') as f:
        for mso in mesh_scale_to_origin_list:
            f.write(str(mso)+'\n')
    with open(os.path.join(LOG_DIR, 'sdf_dirs.txt'), 'w') as f:
        for of in sdf_paths:
            f.write(of+'\n')
    with open(os.path.join(LOG_DIR, 'obj_dirs.txt'), 'w') as f:
        for of in rel_mesh_obj_path:
            f.write(of+'\n')
    
    ema_params = params
    ema_ratio = 0.995
    for itr in tqdm(range(start_itr, 1000_000)):
        if itr%args.save_interval==0 and itr!=0:
            # save checkpoint
            save_dict = {'params':params, 'ema_params':ema_params, 'rot_configs':rot_configs, 'itr':itr, 'steps_cumulative':steps_cumulative,
                         'mesh_obj_path':rel_mesh_obj_path, 'sdf_paths':sdf_paths,
                         'opt_state':opt_state,
                         'col_dec_arg_dict':{k:col_dec.__dict__[k] for k in col_dec.__dict__ if k[:1]!='_'},
                         'mesh_translation_to_origin_list':mesh_translation_to_origin_list,
                         'mesh_scale_to_origin_list':mesh_scale_to_origin_list,
                         }
            if args.use_encoder:
                save_dict['enc_arg_dict'] = {k:enc.__dict__[k] for k in enc.__dict__ if k[:1]!='_'}
            with open(os.path.join(LOG_DIR, 'save_dict.pkl'), 'wb') as f:
                pickle.dump(save_dict, f)
            if params[-1] is not None:
                with open(os.path.join(LOG_DIR, 'latent_obj_dict.pkl'), 'wb') as f:
                    pickle.dump(jax.tree_util.tree_map(lambda x: np.array(x), params[-1]).__dict__, f)
        ckpt_interval = args.save_interval*4
        if itr%ckpt_interval==0 and itr!=0:
            with open(os.path.join(LOG_DIR, f'save_dict_{itr}.pkl'), 'wb') as f:
                pickle.dump(save_dict, f)

        _, jkey = jax.random.split(jkey)
        params, opt_state, jkey, train_losses = train_func(params, opt_state, jkey, inner_itr_no=args.inner_itr_no)
        if itr > 200:
            ema_params = jax.tree_util.tree_map(lambda x,y: (1-ema_ratio)*x+ema_ratio*y, params, ema_params)
        else:
            ema_params = params
        steps_cumulative += args.inner_itr_no
        if itr%100 == 0:
            jkey, subkey = jax.random.split(jkey)
            selected_col_data, occ_dataset_selected, selected_mulcol_data, jkey = sample_dataset(2*OBJ_BATCH_SIZE, args.query_batch_size, subkey, ds_type='test')
            jkey, subkey = jax.random.split(jkey)
            eval_out = eval_func_jit(ema_params, selected_col_data, occ_dataset_selected, selected_mulcol_data, subkey)
            
            for k in eval_out:
                writer.add_scalar(f'EVAL/{k}', np.array(eval_out[k]), itr)
            for k in train_losses:
                writer.add_scalar(f'TRAIN/{k}', np.array(train_losses[k]), itr)
            
            wandb.log({f"EVAL/{k}": eval_out[k] for k in eval_out}, step=itr)
            wandb.log({f"TRAIN/{k}": train_losses[k] for k in train_losses}, step=itr)
            wandb.log({"TRAIN/steps_cumulative": steps_cumulative}, step=itr)

            if vessl_on:
                eval_log_dict = {"EVAL_"+k: eval_out[k] for k in eval_out}
                train_log_dict = {"TRAIN_"+k: train_losses[k] for k in train_losses}
                vessl.log(step=itr, payload={**eval_log_dict, **train_log_dict})
        