import flax.linen as nn
import typing
import jax.numpy as jnp
import numpy as np
import jax
import einops
import importlib
from flax.core.frozen_dict import FrozenDict
from flax import struct
import flax
import pickle
import time
import os, sys
from typing import Sequence, Tuple
from functools import partial
from dataclasses import replace

from pathlib import Path

BASEDIR = os.path.dirname(os.path.dirname(__file__))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import util.ev_util.ev_util as eutil
import util.ev_util.ev_layers as evl
import util.ev_util.rotm_util as rmutil
import util.transform_util as tutil
import util.camera_util as cutil
from dinov2_jax.dino_weights import load_vit_params, load_dino_vits
import util.dif_model_util as dmutil
import util.broad_phase as broad_phase

# try:
#     from dinov2_jax.dino_weights import load_vit_params, load_dino_vits
#     import util.dif_model_util as dmutil
# except:
#     print('not using dino')
import util.latent_obj_util as loutil
import util.ev_util.ev_layers as ev_layers
import util.asset_util as asutil
# import util.kdtree as kdtree
# import util.ann as ann
from util.dotenv_util import REP_CKPT, REP_CKPT_ITR


def load_canonical_objects(obj_path):
    '''
    Load the canonical latent objects stored in a checkpoint directory.

    obj_path: directory containing the checkpoint files sdf_dirs.txt,
        scale_to_origin.txt, translation_to_origin.txt and latent_obj_dict.pkl.
    return: (canonical_latent_objects, mesh_aligned_canonical_latent_objects, obj_names, dataset_names)
        - canonical_latent_objects: loutil.LatentObjects built from latent_obj_dict.pkl
        - mesh_aligned_canonical_latent_objects: the same objects scaled by 1/scale_to_origin
          and then translated by -translation_to_origin
        - obj_names: basename of each sdf path, truncated at its first '.'
        - dataset_names: name of the grandparent directory of each sdf path
    '''
    def _stripped_lines(fname):
        # read one text file from obj_path and return its whitespace-stripped lines
        with open(os.path.join(obj_path, fname), 'r') as fh:
            return [ln.strip() for ln in fh.readlines()]

    sdf_dirs = np.array(_stripped_lines('sdf_dirs.txt'))
    obj_names = []
    dataset_names = []
    for sdf_dir in sdf_dirs:
        obj_names.append(os.path.basename(sdf_dir).split('.')[0])
        dataset_names.append(Path(sdf_dir).parent.parent.name)

    scale_to_origin = np.array([float(s) for s in _stripped_lines('scale_to_origin.txt')])
    # each line holds a whitespace-separated translation vector
    translation_to_origin = np.array([np.fromstring(s, sep=" ") for s in _stripped_lines('translation_to_origin.txt')])

    with open(os.path.join(obj_path, 'latent_obj_dict.pkl'), 'rb') as fh:
        latent_obj_dict = pickle.load(fh)

    canonical = loutil.LatentObjects(**latent_obj_dict)
    mesh_aligned = canonical.apply_scale(1/scale_to_origin).translate(-translation_to_origin)
    return canonical, mesh_aligned, obj_names, dataset_names

@struct.dataclass
class Models:
    '''
    Immutable (flax struct.dataclass) bundle of every network used by the pipeline, together with its
    parameters and the canonical-object metadata needed to run them.

    For each network <name> there are three fields: <name>_params, <name>_batch_stats and <name>_model.
    Methods return updated copies via `replace`; the one exception is load_dino_params, which writes
    into the nested img_encoder params dict in place.
    '''
    pretrain_ckpt_id: str = None
    dif_args: typing.NamedTuple=None
    rot_configs: typing.Sequence=None
    latent_shape: typing.Sequence=None
    pixel_size: typing.Sequence[int]=None

    encoder_params: FrozenDict = None
    encoder_batch_stats: FrozenDict = None
    encoder_model: nn.Module = None

    col_decoder_params: FrozenDict = None
    col_decoder_batch_stats: FrozenDict = None
    col_decoder_model: nn.Module = None

    sdf_decoder_params: FrozenDict = None
    sdf_decoder_batch_stats: FrozenDict = None
    sdf_decoder_model: nn.Module = None

    ray_decoder_params: FrozenDict = None
    ray_decoder_batch_stats: FrozenDict = None
    ray_decoder_model: nn.Module = None

    img_encoder_params: FrozenDict = None
    img_encoder_batch_stats: FrozenDict = None
    img_encoder_model: nn.Module = None

    spatial_PE_params: FrozenDict = None
    spatial_PE_batch_stats: FrozenDict = None
    spatial_PE_model: nn.Module = None

    denoiser_params: FrozenDict = None
    denoiser_batch_stats: FrozenDict = None
    denoiser_model: nn.Module = None

    seg_predictor_params: FrozenDict = None
    seg_predictor_batch_stats: FrozenDict = None
    seg_predictor_model: nn.Module = None

    self_collision_checker_params: FrozenDict = None
    self_collision_checker_batch_stats: FrozenDict = None
    self_collision_checker_model: nn.Module = None

    feature_decoder_params: FrozenDict = None
    feature_decoder_batch_stats: FrozenDict = None
    feature_decoder_model: nn.Module = None

    feature_renderer_params: FrozenDict = None
    feature_renderer_batch_stats: FrozenDict = None
    feature_renderer_model: nn.Module = None

    canonical_latent_obj:loutil.LatentObjects = None
    canonical_latent_obj_filename_list: typing.List[str] = None

    scale_to_origin: np.ndarray = None
    translation_to_origin: np.ndarray = None

    # pytree of bools mirroring `params`; False marks frozen (e.g. DinoViT) leaves
    trainable_mask: FrozenDict = None

    learnable_queries: FrozenDict =None

    asset_path_util: asutil.AssetPaths = None

    @property
    def nh(self)->int:
        '''
        Length of a flattened latent-object vector `h`.

        With latent_shape = (nfps, a, b) this is nfps*a*b + 3 + 3*nfps
        (presumably the z features, pos and rel_fps of loutil.LatentObjects -- confirm there).
        '''
        return self.latent_shape[0] * self.latent_shape[1] * self.latent_shape[2] + 3 + 3 * self.latent_shape[0]

    @staticmethod
    def _build_trainable_mask(params, freeze_keys=('DinoViT',)):
        '''
        Return a pytree of bools with the structure of `params`: a leaf is False (frozen) when any
        entry of its key path contains any string in `freeze_keys`, True otherwise.
        '''
        def _entry_name(entry):
            # DictKey exposes .key, GetAttrKey .name, SequenceKey .idx; stringify so int keys work too
            for attr in ('key', 'name', 'idx'):
                if hasattr(entry, attr):
                    return str(getattr(entry, attr))
            return str(entry)

        def _is_trainable(path, _leaf):
            return not any(fk in _entry_name(entry) for entry in path for fk in freeze_keys)

        return jax.tree_util.tree_map_with_path(_is_trainable, params)

    def load_pretrained_models(self, save_dir=None, ckpt_itr=None, download_assets=True)->"Models":
        '''
        Load the pretrained oriCORN collision decoder and canonical objects from a checkpoint directory.

        save_dir: checkpoint directory; REP_CKPT is used when None.
        ckpt_itr: iteration to load (reads save_dict_{ckpt_itr}.pkl). When None, REP_CKPT_ITR is used if it
            parses as an int. A missing or non-positive value loads save_dict.pkl.
        download_assets: forwarded to asutil.AssetPaths.
        return: a new Models with pretrain_ckpt_id, col_decoder model/params, rot_configs, latent_shape,
            canonical objects, scale/translation-to-origin and asset paths set.
        '''
        if save_dir is None:
            save_dir = REP_CKPT
        if ckpt_itr is None:
            # Only fall back to the environment setting when the caller gave no iteration;
            # an explicit ckpt_itr argument must not be overridden.
            try:
                ckpt_itr = int(REP_CKPT_ITR)
            except (TypeError, ValueError):
                pass
        if ckpt_itr is None or ckpt_itr<=0:
            ckpt_suffix = ''
        else:
            ckpt_suffix = f'_{ckpt_itr}'
        # NOTE: pickle is only safe on trusted checkpoint files
        with open(os.path.join(save_dir, f"save_dict{ckpt_suffix}.pkl"), "rb") as f:
            save_dict = pickle.load(f)
        print(f'loading pretrained models from {save_dir} itr {ckpt_suffix}')
        # last path component, robust to trailing separators
        ckpt_id = os.path.basename(os.path.normpath(save_dir))
        self = self.replace(pretrain_ckpt_id=ckpt_id)

        # older checkpoints store the object list under "sdf_paths"
        sdf_dirs = save_dict["obj_filename_list"] if "obj_filename_list" in save_dict else save_dict["sdf_paths"]
        oricorn_dec_params = save_dict["params"]
        rot_configs = save_dict["rot_configs"]

        col_decoder = ColDecoderSceneV2(**save_dict['col_dec_arg_dict'])

        if len(oricorn_dec_params) == 5:
            # checkpoint with sdf / col / ray decoder params
            sdf_dec_params, col_dec_params, ray_dec_params, point_latent_obj, latent_obj_list = oricorn_dec_params
            params = {'sdf_decoder': sdf_dec_params, 'col_decoder': col_dec_params, 'ray_decoder': ray_dec_params}
        else:
            # checkpoint with collision decoder params only
            col_dec_params, point_latent_obj, latent_obj_list = oricorn_dec_params
            params = {'col_decoder': col_dec_params}

        with open(os.path.join(save_dir, 'scale_to_origin.txt'), 'r') as f:
            scale_to_origin = np.array([float(line.strip()) for line in f.readlines()])
        with open(os.path.join(save_dir, 'translation_to_origin.txt'), 'r') as f:
            translation_to_origin = np.array([np.fromstring(line.strip(), sep=" ") for line in f.readlines()])
        self = self.replace(scale_to_origin=scale_to_origin, translation_to_origin=translation_to_origin)

        self = self.replace(asset_path_util=asutil.AssetPaths(save_dir, download_assets=download_assets))

        names = self.pretraining_names
        # models = [sdf_decoder, col_decoder, ray_decoder]
        models = [col_decoder]
        self = self.replace(rot_configs=rot_configs)
        self = self.replace(canonical_latent_obj_filename_list=sdf_dirs)
        self = self.replace(canonical_latent_obj=latent_obj_list.init_pos_zero())
        self = self.replace(**{name+'_model': model for name, model in zip(names, models)})
        self = self.replace(latent_shape=latent_obj_list.latent_shape)
        # set_params also resets learnable_queries to None (not part of this checkpoint's params)
        self = self.set_params(params)
        return self


    def load_self_collision_model(self, robot)->"Models":
        '''
        Load the self-collision checker network for `robot`.

        robot: one of 'shakey', 'RobotBimanualV4', 'im2', 'ur5', 'shakey_robotiq'.
        return: a new Models with self_collision_checker_model/_params set.
        raises: KeyError if no checkpoint is registered for `robot`.
        '''
        print('loading self collision model')
        from train_shakey_self_collision_net import ShakeySelfCollisionNet
        # checkpoint paths are relative -- they resolve against the current working directory
        ckpt_paths = {'shakey':'checkpoints/rep_ckpt/self_collision_shakey/model_shakey_epoch_40.ckpt',
                      'RobotBimanualV4': 'checkpoints/rep_ckpt/self_collision_im2/model_IM2_epoch_200.ckpt',
                      'im2':'checkpoints/rep_ckpt/self_collision_im2/model_IM2_epoch_200.ckpt',
                      'ur5':'checkpoints/rep_ckpt/self_collision_ur5/model_ur5_epoch_100.ckpt',
                      'shakey_robotiq':'checkpoints/rep_ckpt/self_collision_shakey_robotiq/model_shakey_robotiq_epoch_70.ckpt'}
        if robot not in ckpt_paths:
            raise KeyError(f'no self-collision checkpoint registered for robot {robot!r}; '
                           f'available: {sorted(ckpt_paths)}')
        self_collision_checker_model = ShakeySelfCollisionNet()
        with open(ckpt_paths[robot], 'rb') as f:
            ckpt_bytes = f.read()
        # from_bytes needs a params template with the right structure; input is a (1, 6) joint vector
        initial_params = self_collision_checker_model.init(jax.random.PRNGKey(0), jnp.ones([1, 6]))['params']
        self_collision_checker_params = flax.serialization.from_bytes(initial_params, ckpt_bytes)
        self_collision_checker_params = {'params':self_collision_checker_params}
        self = self.replace(self_collision_checker_params=self_collision_checker_params, self_collision_checker_model=self_collision_checker_model)
        return self

    def load_dino_params(self):
        '''
        Overwrite the DinoViT_0 sub-tree of img_encoder_params with pretrained DINO weights.

        No-op unless dif_args.dino_size is one of 's', 'b', 'l'. The nested params dict is modified
        IN PLACE (it must be a mutable dict, not a FrozenDict); `self` is returned for chaining.
        '''
        if self.dif_args.dino_size not in ['s', 'b', 'l']:
            return self

        # initialize a fresh params tree (dummy 238x420 image) to get the structure load_vit_params fills
        jkey = jax.random.PRNGKey(0)
        rgbs = np.zeros((1, 238, 420, 3), dtype=np.uint8)
        cam_posquat = np.zeros((1, 7), dtype=np.float32)
        cam_intrinsic = np.zeros((1, 6), dtype=np.float32)
        _, imgenc_params = self.img_encoder_model.init_with_output(jkey, rgbs, cam_posquat, cam_intrinsic)

        imgenc_params_dino = {'params':load_vit_params(imgenc_params['params'], None, dino_size=self.dif_args.dino_size)}
        self.img_encoder_params['params']['DinoViT_0'] = imgenc_params_dino['params']['DinoViT_0']
        return self

    def init_emb(self, outer_shape):
        '''
        Return four zero arrays of shapes outer_shape+(base_dim,) and 3x outer_shape+(nfps, base_dim),
        where nfps = latent_shape[0] and base_dim = dif_args.base_dim.
        '''
        nfps = self.latent_shape[0]
        return (jnp.zeros(outer_shape + (self.dif_args.base_dim,)), jnp.zeros(outer_shape + (nfps, self.dif_args.base_dim,)),
                jnp.zeros(outer_shape + (nfps, self.dif_args.base_dim,)), jnp.zeros(outer_shape + (nfps, self.dif_args.base_dim,)))

    def init_dif_model_scenedata(self, args, ds)->"Models":
        '''
        Build and initialize the scene-estimation networks from a sample batch.

        args: run configuration; fields read here include seed, dino_size, render_loss_weight, train_seg,
            dm_type and nparticles (args is also passed to DenoisingModel / SegModel).
        ds: sample batch dict with "rgbs", "cam_info" ("cam_posquats", "cam_intrinsics") and "obj_info".
        return: a new Models with img_encoder, spatial_PE, denoiser, seg_predictor and feature_renderer
            models/params, learnable_queries (for the dm_types handled below), pixel_size, latent_shape,
            dif_args and a trainable_mask that freezes every DinoViT parameter.
        Requires rot_configs and the canonical-object fields, e.g. from load_pretrained_models.
        '''
        jkey = jax.random.PRNGKey(args.seed+41)

        use_dino = args.dino_size in ['s', 'b', 'l']
        if use_dino:
            imgenc_model = dmutil.ImageFeatureEntireV2(args.dino_size)
        else:
            imgenc_model = dmutil.ImageFeatureEntireCNN()
        denoise_model = dmutil.DenoisingModel(args, self.rot_configs)
        spatial_PE_model = dmutil.SpatialPE()
        if args.render_loss_weight != 0:
            oriCORN_renderer = dmutil.oriCORNRenderer(128, self.rot_configs)
        else:
            oriCORN_renderer = None

        # init all variables from the sample batch
        rgbs = ds["rgbs"]
        pixel_size = rgbs.shape[-3:-1]
        cam_info = ds["cam_info"]
        obj_info = ds["obj_info"]
        latent_obj = loutil.LatentObjects().init_obj_info(obj_info, self.mesh_aligned_canonical_obj, self.rot_configs)
        _, jkey = jax.random.split(jkey)
        img_feat_structs, imgenc_params = imgenc_model.init_with_output(jkey, rgbs, cam_info["cam_posquats"].astype(jnp.float32), cam_info["cam_intrinsics"].astype(jnp.float32), dino_feat_out=True)
        _, jkey = jax.random.split(jkey)
        img_feat_structs, spatial_PE_params = spatial_PE_model.init_with_output(jkey, img_feat_structs)
        if oriCORN_renderer is not None:
            feature_renderer_params = oriCORN_renderer.init(jkey, latent_obj, img_feat_structs, cam_info["cam_posquats"][...,-1,:].astype(jnp.float32),
                                                            img_feat_structs.intrinsic[...,-1,:].astype(jnp.float32), output_pixel_size=pixel_size)
        else:
            feature_renderer_params = None
        time_cond = jnp.ones((latent_obj.outer_shape[0], ))
        self = self.replace(pixel_size=pixel_size)
        self = self.replace(latent_shape=latent_obj.latent_shape)
        _, jkey = jax.random.split(jkey)
        if args.train_seg:
            seg_model = dmutil.SegModel(args)
            seg_params = seg_model.init(jkey, img_feat_structs)
        else:
            seg_model = None
            seg_params = None

        if use_dino:
            # replace the randomly initialized ViT backbone with pretrained DINO weights
            imgenc_params = {'params':load_vit_params(imgenc_params['params'], None, args.dino_size)}
        else:
            # the CNN encoder has no DinoViT sub-module, so there are no ViT weights to load;
            # keep only the 'params' collection, as the DINO branch does
            imgenc_params = {'params':imgenc_params['params']}

        denoiser_params = denoise_model.init({'params':jkey, 'dropout':jkey}, latent_obj, img_feat_structs, time_cond)
        if args.dm_type == 'regression':
            learnable_queries = jax.random.normal(jkey, shape=(args.nparticles, self.nh), dtype=jnp.float32)
            self = self.replace(learnable_queries=learnable_queries)
        elif args.dm_type == 'fm_cat':
            learnable_queries = jax.random.normal(jkey, shape=(2*args.nparticles, self.nh), dtype=jnp.float32)
            self = self.replace(learnable_queries=learnable_queries)
        elif args.dm_type in ['fm_reg', 'fm_vel']:
            learnable_queries = jax.random.normal(jkey, shape=(args.nparticles, self.nh), dtype=jnp.float32)
            learnable_queries_obj = loutil.LatentObjects().init_h(learnable_queries, self.latent_shape)
            # double the spread of the latent z part only
            learnable_queries_obj = learnable_queries_obj.replace(z=2*learnable_queries_obj.z)
            learnable_queries = learnable_queries_obj.h
            self = self.replace(learnable_queries=learnable_queries)

        names = self.difmodel_names
        models = [imgenc_model, spatial_PE_model, denoise_model, seg_model, oriCORN_renderer]
        params = (imgenc_params, spatial_PE_params, denoiser_params, seg_params, feature_renderer_params)
        self = self.replace(**{name+'_model': model for name, model in zip(names, models)})
        self = self.replace(**{name+'_params': p for name, p in zip(names, params)})
        self = self.replace(dif_args=args)

        # freeze the pretrained DINO backbone; everything else is trainable
        self = self.replace(trainable_mask=self._build_trainable_mask(self.params, ('DinoViT',)))
        return self


    def init_img_feat_models(self, args, ds)->"Models":
        '''
        Build and initialize the image-feature networks (DINO encoder, spatial PE, oriCORN feature
        renderer and feature decoder) from a sample batch.

        args: run configuration; reads seed, dino_size and base_dim.
        ds: sample batch dict with "rgbs", "cam_info" and "obj_info".
        return: a new Models with those models/params, pixel_size, latent_shape, dif_args and a
            trainable_mask that freezes every DinoViT parameter.
        '''
        jkey = jax.random.PRNGKey(args.seed+41)

        imgenc_model = dmutil.ImageFeatureEntireV2(args.dino_size)
        spatial_PE_model = dmutil.SpatialPE()
        oriCORN_renderer = dmutil.oriCORNRenderer(args.base_dim, self.rot_configs)
        feature_decoder = dmutil.FeatureDecoder(args.base_dim)

        # init all variables from the sample batch
        rgbs = ds["rgbs"]
        pixel_size = rgbs.shape[-3:-1]
        self = self.replace(pixel_size = pixel_size)
        cam_info = ds["cam_info"]
        obj_info = ds["obj_info"]
        latent_obj = loutil.LatentObjects().init_obj_info(obj_info, self.mesh_aligned_canonical_obj, self.rot_configs)
        _, jkey = jax.random.split(jkey)
        img_feat_structs, imgenc_params = imgenc_model.init_with_output(jkey, rgbs, cam_info["cam_posquats"].astype(jnp.float32), cam_info["cam_intrinsics"].astype(jnp.float32))
        _, jkey = jax.random.split(jkey)
        img_feat_structs, spatial_PE_params = spatial_PE_model.init_with_output(jkey, img_feat_structs)

        feature_renderer_params = oriCORN_renderer.init(jkey, latent_obj, img_feat_structs, cam_info["cam_posquats"][...,-1,:].astype(jnp.float32),
                                                        img_feat_structs.intrinsic[...,-1,:].astype(jnp.float32), output_pixel_size=pixel_size)
        feature_decoder_params = feature_decoder.init(jkey, img_feat_structs, pixel_size)

        # replace the randomly initialized ViT backbone with pretrained DINO weights
        imgenc_params = {'params':load_vit_params(imgenc_params['params'], None, args.dino_size)}

        self = self.replace(latent_shape=latent_obj.latent_shape)
        names = ['img_encoder', 'spatial_PE', 'feature_renderer', 'feature_decoder']
        models = [imgenc_model, spatial_PE_model, oriCORN_renderer, feature_decoder]
        params = (imgenc_params, spatial_PE_params, feature_renderer_params, feature_decoder_params)
        self = self.replace(**{name+'_model': model for name, model in zip(names, models)})
        self = self.replace(**{name+'_params': p for name, p in zip(names, params)})
        self = self.replace(dif_args=args)

        # freeze the pretrained DINO backbone; everything else is trainable
        self = self.replace(trainable_mask=self._build_trainable_mask(self.params, ('DinoViT',)))
        return self


    def cal_statics(self):
        '''
        Report parameter counts and mean forward latency of the image encoder and the denoiser.

        Runs both jitted models on random dummy inputs (batch 1, 3 views, 7 objects) and prints a summary.
        Requires pixel_size, latent_shape and the img_encoder / spatial_PE / denoiser models+params.
        return: dict with 'img_enc_num_params', 'img_enc_time', 'den_num_params', 'den_time' (seconds).
        '''
        jkey = jax.random.PRNGKey(71)
        test_nb = 1
        test_nv = 3
        test_no = 7
        rgb_dummy = jax.random.randint(jkey, shape=(test_nb, test_nv, *self.pixel_size, 3), minval=0, maxval=255, dtype=jnp.uint8)
        cam_posquats_dummy = jax.random.normal(jkey, shape=(test_nb, test_nv, 7), dtype=jnp.float32)
        cam_intrinsic_dummy = jax.random.normal(jkey, shape=(test_nb, test_nv, 6), dtype=jnp.float32)
        latent_obj_dummy = loutil.LatentObjects().init_h(jax.random.normal(jkey, shape=(test_nb, test_no, self.nh), dtype=jnp.float32), self.latent_shape)
        time_dummy = jax.random.uniform(jkey, shape=(test_nb,))

        def cal_num_params(params_):
            # total element count over all leaves (jax.tree_map is deprecated / removed in recent JAX)
            return int(sum(np.prod(x.shape) for x in jax.tree_util.tree_leaves(params_)))
        imgenc_num_params = cal_num_params(self.img_encoder_params)
        denoiser_num_params = cal_num_params(self.denoiser_params)

        enc_app = jax.jit(self.img_encoder_model.apply)
        den_app = jax.jit(self.denoiser_model.apply)

        def test_time(func, inputs, n_iter=100):
            # the first call compiles the jitted function and is excluded from the timing
            y = jax.block_until_ready(func(*inputs, rngs={'dropout':jkey}))
            stim = time.perf_counter()
            for _ in range(n_iter):
                y = jax.block_until_ready(func(*inputs, rngs={'dropout':jkey}))
            return (time.perf_counter() - stim) / n_iter, y

        tim_enc, img_feat_struct_dummy = test_time(enc_app, (self.img_encoder_params, rgb_dummy, cam_posquats_dummy, cam_intrinsic_dummy))
        img_feat_struct_dummy = self.apply('spatial_PE', img_feat_struct_dummy, train=False)
        tim_dec, _ = test_time(den_app, (self.denoiser_params, latent_obj_dummy, img_feat_struct_dummy, time_dummy))

        print(f'img enc num params: {imgenc_num_params//1000}k, denoiser: {denoiser_num_params//1000}k')
        print(f'duration img_enc: {tim_enc*1000}ms, denoiser: {tim_dec*1000}ms')

        return {'img_enc_num_params':imgenc_num_params, 'img_enc_time':tim_enc, 'den_num_params':denoiser_num_params, 'den_time':tim_dec}


    def apply(self, name, *args, train=False, **kwargs):
        '''
        Run network `name` (e.g. 'denoiser') with its stored params.

        When train is False the params are wrapped in stop_gradient. `train` is forwarded to the
        module; extra positional/keyword arguments are passed through to model.apply.
        '''
        model = getattr(self, name+'_model')
        params = getattr(self, name+'_params')
        if not train:
            params = jax.lax.stop_gradient(params)
        return model.apply(params, *args, train=train, **kwargs)

    def set_params(self, params:typing.Dict, update_trainable_only=False)->"Models":
        '''
        Return a copy with <name>_params replaced for every key `name` in `params`.

        params: dict keyed by network name (without the '_params' suffix) plus optional 'learnable_queries'.
            learnable_queries is reset to None when that key is absent.
        update_trainable_only: when True, `params` must mirror self.params (e.g. from trainable_params);
            only leaves whose trainable_mask is True are taken from it, the rest keep their current value.
        '''
        if update_trainable_only:
            params = jax.tree_util.tree_map(lambda mask, x, y: x if mask else y, self.trainable_mask, params, self.params)
        self = self.replace(**{k+'_params': params[k] for k in params if k!='learnable_queries'})
        self = self.replace(learnable_queries=params['learnable_queries'] if 'learnable_queries' in params else None)
        return self

    def pairwise_collision_prediction(self, objsA:loutil.LatentObjects, objsB:loutil.LatentObjects, jkey, reduce_k=16, train=False):
        '''
        Predict collision logits for every pair of objects from objsA and objsB.

        objsA: outer shape (... NA)
        objsB: outer shape (... NB)
        return: collision logits (... NA NB); pairs involving an invalid object are set to -100.
        '''
        na = objsA.outer_shape[-1]
        nb = objsB.outer_shape[-1]
        # objects with every coordinate below 8 are valid; others are treated as padding.
        # NOTE(review): only the upper bound is checked, so positions far below -8 still count as valid.
        valid_objA_mask = jnp.all(objsA.pos<8.0, axis=-1)
        valid_objB_mask = jnp.all(objsB.pos<8.0, axis=-1)
        objsA = objsA.extend_and_repeat_outer_shape(nb, -1)
        objsB = objsB.extend_and_repeat_outer_shape(na, -2)
        col_res, _, _ = self.col_decoder_model.apply(self.col_decoder_params, objsA, objsB, reduce_k=reduce_k, jkey=jkey, train=train)
        col_res = jnp.where(valid_objA_mask[...,None], col_res, -100)
        col_res = jnp.where(valid_objB_mask[...,None,:], col_res, -100)
        return col_res

    def occ_prediction(self, latent_obj:loutil.LatentObjects, qpnts:jnp.ndarray, jkey, reduce_k=16, merge=False, train=False):
        '''
        Predict occupancy of query points by treating each point as a degenerate oriCORN
        (zero latent, zero pos/rel_fps, translated to the point) and running the collision decoder.

        latent_obj: outer shape (... N)
        qpnts: (..., nq, 3); a single point of shape (3,) is also accepted
        merge: merge all objects into one before querying
        return: occupancy logits (..., nq)
        '''
        extended = False
        if qpnts.ndim == 1:
            extended = True
            qpnts = qpnts[None]
        if merge:
            latent_obj = latent_obj.merge(keepdims=True)
        nq = qpnts.shape[-2]
        point_oriCORN:loutil.LatentObjects = latent_obj.replace(z = jnp.zeros_like(latent_obj.z[...,:1,:,:]), pos = jnp.zeros_like(latent_obj.pos), rel_fps = jnp.zeros_like(latent_obj.rel_fps[...,:1,:]))
        point_oriCORN = point_oriCORN.extend_and_repeat_outer_shape(nq, -1)
        point_oriCORN = point_oriCORN.translate(qpnts)
        # a singleton axis lets the decoder broadcast objects against the nq query points
        latent_obj = latent_obj.extend_and_repeat_outer_shape(1, -1)

        occ_pred1, _, _ = self.col_decoder_model.apply(self.col_decoder_params, latent_obj, point_oriCORN,
                                                reduce_k=reduce_k, jkey=jkey, train=train)
        if extended:
            occ_pred1 = occ_pred1.squeeze(-2)
        return occ_pred1

    def ray_prediction(self, target_oriCORN:loutil.LatentObjects, p0, dir, jkey, reduce_k=16, depth_multiplier=10.0, merge=False, train=False):
        '''
        Cast rays from p0 along dir (scaled by depth_multiplier) against target_oriCORN with the collision decoder.

        target_oriCORN: latent objects (..., NO)
        p0: (..., NQ, 3) ray origins
        dir: (..., NQ, 3) ray directions (broadcast against p0)
        return: seg logits, and depth = distance from p0 to the nearest fps point whose patch logit is
            positive (10 where no patch is hit)
        '''
        if merge:
            target_oriCORN = target_oriCORN.merge(keepdims=True)
        query_outer_shape = jnp.broadcast_shapes(p0.shape[:-1], dir.shape[:-1])
        p0 = jnp.broadcast_to(p0, query_outer_shape+(3,))
        dir = jnp.broadcast_to(dir, query_outer_shape+(3,))*depth_multiplier
        nq = p0.shape[-2]
        jkey, subkey = jax.random.split(jkey)
        # each ray is a degenerate oriCORN at its origin; the segment itself is passed as line_segment_B
        ray_oriCORN:loutil.LatentObjects = target_oriCORN.replace(z = jnp.zeros_like(target_oriCORN.z[...,:1,:,:]), pos = jnp.zeros_like(target_oriCORN.pos), rel_fps = jnp.zeros_like(target_oriCORN.rel_fps[...,:1,:]))
        ray_oriCORN = ray_oriCORN.extend_and_repeat_outer_shape(nq, -1)
        ray_oriCORN = ray_oriCORN.translate(p0)
        target_oriCORN = target_oriCORN.extend_and_repeat_outer_shape(1, -1)

        seg_logit, seg_patch_logit, _ = self.col_decoder_model.apply(self.col_decoder_params, target_oriCORN, ray_oriCORN, line_segment_B=dir,
                                            reduce_k=reduce_k, jkey=subkey, train=train)

        hitting_mask = (seg_patch_logit>0).squeeze(-1)

        # depth approximated by the distance to the closest hit fps point
        depth_per_fps = jnp.linalg.norm(p0[...,None,:] - target_oriCORN.fps_tf, axis=-1)
        depth_per_fps = jnp.where(hitting_mask, depth_per_fps, 10)
        depth = jnp.min(depth_per_fps, axis=-1, keepdims=True)

        return seg_logit, depth

    def get_model_apply_func(self, spatial_PE=False, jit=True):
        '''
        Return f(x, cond, t, cond_mask, jk, previous_emb) that runs the denoiser (optionally applying
        spatial_PE to `cond` first), jitted when jit is True. `jk` is used as the dropout rng.
        '''
        def model_apply_param_jit(x, cond, t, cond_mask, jk, previous_emb):
            if spatial_PE:
                cond = self.apply('spatial_PE', cond, train=False)
            return self.apply('denoiser', x, cond, t, cond_mask, previous_emb=previous_emb, rngs={'dropout':jk})
        if jit:
            return jax.jit(model_apply_param_jit)
        else:
            return model_apply_param_jit


    @property
    def names(self):
        '''All network names (fields ending in "_model"), excluding the self-collision checker.'''
        return [k[:-6] for k in self.__annotations__ if (k[-5:]=='model' and k[:-6] != 'self_collision_checker')]

    @property
    def pretraining_names(self):
        # return ['sdf_decoder', 'col_decoder', 'ray_decoder']
        return ['col_decoder']

    @property
    def difmodel_names(self):
        return['img_encoder', 'spatial_PE', 'denoiser', 'seg_predictor', 'feature_renderer']

    @property
    def params(self):
        '''Dict of every network's params (see `names`) plus 'learnable_queries'.'''
        return {**{name: getattr(self, name+'_params') for name in self.names}, **{'learnable_queries': self.learnable_queries}}


    @property
    def estimator_params(self):
        return {**{name: getattr(self, name+'_params') for name in self.difmodel_names}, **{'learnable_queries': self.learnable_queries}}

    @property
    def pretraining_params(self):
        return {name: getattr(self, name+'_params') for name in self.pretraining_names}

    @property
    def trainable_params(self):
        '''self.params with every frozen leaf (trainable_mask False) replaced by None.'''
        return jax.tree_util.tree_map(lambda mask, x: x if mask else None, self.trainable_mask, self.params)


    @property
    def mesh_aligned_canonical_obj(self)->loutil.LatentObjects:
        '''Canonical objects scaled by 1/scale_to_origin, centred at zero, then shifted by -translation_to_origin.'''
        return self.canonical_latent_obj.apply_scale(1/self.scale_to_origin).init_pos_zero().translate(-self.translation_to_origin)



def visualize_collision(reduced_A, reduced_B, line_segment_B, time_B, fixed_oriCORN_A, canonical_oriCORNs_B, pqc_path_B, pairwise_feat):
    '''
    Debug visualization (open3d) of a swept-volume collision query between a fixed object set A
    and a moving object set B.

    Draws the fps-point trajectories of B as red line segments, the fps points of A and of the reduced
    A/B sets as point clouds, reconstructed meshes for reduced A (fixed) and the swept volume of reduced B,
    and green spheres at fps points of reduced A whose pairwise_feat is positive.

    line_segment_B: (NQ, NFPSA_reduced, NFPSB_reduced, 3)
    time_B: (NQ, NFPSA_reduced, NFPSB_reduced)
    pqc_path_B: (NQ, NAC, NOB, ...) -- the last axis is split into [:3] (position) and [3:]
        (presumably a quaternion) for tutil.pq_action; TODO confirm its width (earlier doc said 3)
    pairwise_feat: (NQ, NFPSA_reduced, NFPSB_reduced, 1)
    return: list of open3d geometries [line, pcd_A_entire, pcd_A, pcd_time_opt, frames, *spheres]
        for the first query only (see NOTE at the return)
    '''
    import open3d as o3d
    from util.reconstruction_util import create_scene_mesh_from_oriCORNs, create_swept_volume_from_oriCORNs

    fps_tf_A = fixed_oriCORN_A.fps_tf

    # transform B's canonical fps points along every pose of the path
    fps_path_B = tutil.pq_action(pqc_path_B[...,None,:3], pqc_path_B[...,None,3:], canonical_oriCORNs_B.fps_tf) # (NQ, NAC, NOB, NFPSB, 3)
    fps_path_B_merged = einops.rearrange(fps_path_B, ' ... i j k p -> ... i (j k) p') # (NQ, NAC, NOB*NFPSB, 3)
    # pair each path step with the previous one: (NQ, NAC-1, NOB*NFPSB, 3, 2)
    fps_path_B_merged_seq = jnp.stack([fps_path_B_merged[...,1:,:,:], fps_path_B_merged[...,:-1,:,:]], axis=-1) # (NQ, NAC, )
    

    for i in range(fps_path_B_merged_seq.shape[0]):
        sphere_o3ds = []
        # NOTE(review): branch on reduced_A.ndim -- presumably per-query reduced A (indexed by i)
        # versus a single shared A; confirm the meaning of ndim for this object type
        if reduced_A.ndim == 1:
            rec_mesh_fixed = create_scene_mesh_from_oriCORNs(reduced_A[i], qp_bound=5.0, density=400, ndiv=800, visualize=False)
            # specify collision position with spheres
            for j in range(pairwise_feat.shape[1]):
                if np.any(pairwise_feat[i,j] > 0):
                    sphere_o3d = o3d.geometry.TriangleMesh.create_sphere(radius=0.05)
                    sphere_o3d.translate(reduced_A[i].fps_tf.reshape(-1,3)[j])
                    sphere_o3d.paint_uniform_color([0, 1, 0])
                    sphere_o3ds.append(sphere_o3d)
        else:
            # shared A: its mesh is built only once, on the first query
            if i==0:
                rec_mesh_fixed = create_scene_mesh_from_oriCORNs(reduced_A, qp_bound=5.0, density=400, ndiv=800, visualize=False)
            # specify collision position with spheres
            for j in range(pairwise_feat.shape[1]):
                if np.any(pairwise_feat[i,j] > 0):
                    sphere_o3d = o3d.geometry.TriangleMesh.create_sphere(radius=0.05)
                    sphere_o3d.translate(reduced_A.fps_tf.reshape(-1,3)[j])
                    sphere_o3d.paint_uniform_color([0, 1, 0])
                    sphere_o3ds.append(sphere_o3d)


        # flatten to (num_segments, 2, 3) so consecutive points (2k, 2k+1) form one line segment
        link_fps_seq_original_ = fps_path_B_merged_seq[i]
        link_fps_seq_flat = link_fps_seq_original_.reshape(-1, link_fps_seq_original_.shape[-2], link_fps_seq_original_.shape[-1])
        link_fps_seq_flat = jnp.moveaxis(link_fps_seq_flat, -1, -2)
        fixed_idx = np.stack([2*np.arange(link_fps_seq_flat.shape[0]), 2*np.arange(link_fps_seq_flat.shape[0])+1], axis=-1).astype(np.int32)
        line = o3d.geometry.LineSet()
        line.points = o3d.utility.Vector3dVector(link_fps_seq_flat.reshape(-1,3))
        line.lines = o3d.utility.Vector2iVector(fixed_idx)
        line.paint_uniform_color([1, 0, 0])

        pcd_A_entire = o3d.geometry.PointCloud()
        pcd_A_entire.points = o3d.utility.Vector3dVector(fps_tf_A.reshape(-1,3))
        pcd_A_entire.paint_uniform_color([0.01, 0, 0])
        pcd_A = o3d.geometry.PointCloud()
        pcd_A.points = o3d.utility.Vector3dVector(reduced_A.fps_tf[i].reshape(-1,3))
        pcd_A.paint_uniform_color([1, 0, 0])
        pcd_time_opt = o3d.geometry.PointCloud()
        pcd_time_opt.points = o3d.utility.Vector3dVector(reduced_B.fps_tf[i].reshape(-1,3))
        pcd_time_opt.paint_uniform_color([0, 0, 1])
        frames = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.6, origin=[0, 0, 0])

        rec_mesh = create_swept_volume_from_oriCORNs(reduced_B[i], line_segment_B[i], time_B[i], qp_bound=1.8, density=200, ndiv=400, visualize=False)

        # 

        rec_mesh.paint_uniform_color([0, 1, 1])
        rec_mesh_fixed.paint_uniform_color([1, 1, 0])

        # blocking open3d viewer window
        o3d.visualization.draw_geometries([line, pcd_A_entire, pcd_A, pcd_time_opt, rec_mesh, rec_mesh_fixed, frames])
        # o3d.visualization.draw_geometries([line, pcd_A_entire, pcd_A, pcd_time_opt, frames])

        # return [line, pcd_A_entire, pcd_A, pcd_time_opt, rec_mesh, rec_mesh_fixed, frames]
        # NOTE(review): this return is inside the for-loop, so only the first query (i == 0) is ever
        # drawn and returned, and the meshes are not included in the result -- confirm this is intended
        return [line, pcd_A_entire, pcd_A, pcd_time_opt, frames, *sphere_o3ds]


# @partial(jax.jit, static_argnums=[1])
def FPS_padding(pnts, k, jkey):
    '''
    Farthest point sampling: pick k points from pnts (..., N, 3), returning (..., k, 3).

    The walk is seeded by pnts[..., 0, :]. The seed is not itself returned, but its distances
    stay in the running minimum. The first returned point is the one farthest from the seed, and
    each next point is the one farthest from everything considered so far.
    k == 1 returns the centroid of pnts instead. jkey is accepted for API symmetry but unused.
    '''
    if k == 1:
        return jnp.mean(pnts, axis=-2, keepdims=True)

    def _farthest(min_sq_dist):
        # gather the point with the largest running min-distance, per batch element
        far_idx = jnp.argmax(min_sq_dist, axis=-1)
        return jnp.take_along_axis(pnts, far_idx[..., None, None], axis=-2).squeeze(-2)

    seed = pnts[..., 0:1, :]
    min_sq_dist = jnp.min(jnp.sum((pnts[..., None, :] - seed[..., None, :, :])**2, axis=-1), axis=-1)
    first = _farthest(min_sq_dist)
    # fill every slot with the first pick; slots 1..k-1 are overwritten by the loop
    chosen = jnp.repeat(first[..., None, :], k, axis=-2)

    def _step(slot, state):
        picked, running_min = state
        latest = picked[..., slot, :]
        running_min = jnp.minimum(running_min, jnp.sum((pnts - latest[..., None, :])**2, axis=-1))
        picked = picked.at[..., slot + 1, :].set(_farthest(running_min))
        return picked, running_min

    chosen, _ = jax.lax.fori_loop(0, k - 1, _step, (chosen, min_sq_dist))
    return chosen
class ShapeEncoder(nn.Module):
    '''Encode a surface point cloud into per-patch latent features.

    The cloud is split into ``nfps`` patches centred on farthest-point samples
    (``FPS_padding``). Each patch is centred, rescaled, and encoded with the
    FER-VN layers from ``util.ev_util``. The resulting features are multiplied
    back by each patch's scale and returned as a ``loutil.LatentObjects`` whose
    ``pos`` is the mean of the FPS centres.
    '''
    rot_configs:Sequence   # rotation configuration for the equivariant layers
    output_feat_dim:int    # feature width of the final Dense layer
    nfps:int               # number of FPS centres, i.e. number of patches

    @nn.compact
    def __call__(self,
                 surface_points:jnp.ndarray,
                 jkey=None, train=False):
        '''
        surface_points: (..., N, 3) surface points.
        jkey: forwarded to FPS_padding (which does not use it).
        train: unused.
        return: loutil.LatentObjects with z of shape (..., NFPS, F, D) and
            fps_tf set to the (..., NFPS, 3) FPS centres.
        '''
        pcd_tf = surface_points
        fps_tf = FPS_padding(pcd_tf, self.nfps, jkey)  # (..., NFPS, 3)
        npnts = pcd_tf.shape[-2]
        n_centres = fps_tf.shape[-2]

        # Make patches: the K nearest surface points around every FPS centre.
        # K is clamped to N because jax.lax.top_k rejects k larger than the axis
        # (e.g. nfps == 1 would otherwise ask for 1.5*N neighbours).
        npnt_per_patch = min(max(int(1.5*(npnts/n_centres)), 3), npnts)
        pairwise_dist = jnp.linalg.norm(fps_tf[...,None,:] - pcd_tf[...,None,:,:], axis=-1)  # (..., NFPS, N)
        neg_knn_dist, knn_idx = jax.lax.top_k(-pairwise_dist, k=npnt_per_patch)  # (..., NFPS, K)

        pcd_patch = jnp.take_along_axis(pcd_tf[...,None,:,:], knn_idx[...,None], axis=-2)  # (..., NFPS, K, 3)
        pcd_patch = pcd_patch - fps_tf[...,None,:]  # centre each patch on its FPS point

        # Patch scale: mean of the largest few neighbour distances, shrunk by 0.7.
        patch_scale = jnp.mean(jax.lax.top_k(-neg_knn_dist, k=max(int(npnt_per_patch//8), 1))[0], axis=-1)  # (..., NFPS)
        patch_scale = patch_scale*0.7
        # NOTE(review): a degenerate patch (all neighbours at the centre) gives
        # scale 0 and a division by zero below.
        pcd_patch = pcd_patch/patch_scale[...,None,None]  # (..., NFPS, K, 3)

        # process each patch with FER-VN
        base_dim = self.output_feat_dim*4

        x = jnp.expand_dims(pcd_patch, -1)
        x = eutil.get_graph_feature(x, k=min(5, x.shape[-3]), cross=True)
        x = evl.MakeHDFeature(self.rot_configs, mlp_norm=True)(x)
        x = evl.EVLinearLeakyReLU(base_dim)(x)
        x = jnp.mean(x, -3)

        # Size of the axis that is mean-pooled and re-broadcast in the blocks below.
        # (Previously named `np`, which shadowed the numpy module in this method.)
        n_pooled = x.shape[-3]
        x = nn.Dense(base_dim, use_bias=False)(x)
        for _ in range(4):
            x = evl.EVNResnetBlockFC(base_dim, base_dim)(x)
            # concatenate a pooled (global) feature onto every element
            pooled_global = einops.repeat(jnp.mean(x, -3), '... f d -> ... r f d', r=n_pooled)
            x = jnp.concatenate([x, pooled_global], -1)
        x = evl.EVNResnetBlockFC(base_dim, base_dim)(x)

        # Aggregation
        x = jnp.mean(x, axis=-3)  # (..., NFPS, F, D)
        x = evl.EVNNonLinearity()(x)

        # Final layer
        x = nn.Dense(self.output_feat_dim, use_bias=False)(x)

        # recover the per-patch scale removed above
        x = x*patch_scale[...,None,None]  # (..., NFPS, F, D)

        latent_obj = loutil.LatentObjects(z=x, pos=jnp.mean(fps_tf, axis=-2)).set_fps_tf(fps_tf)
        return latent_obj




class ColDecoderSceneV2(nn.Module):
    '''Pairwise collision decoder between the FPS patches of two latent objects.

    Both objects are first reduced by a broad-phase routine from ``broad_phase``.
    Which routine runs depends on ``path_check`` / ``broadphase_type``, and
    ``reduce_k`` is forwarded to it. Every (A patch, B patch) pair is then
    canonicalised (norm-ordered, rotated, rescaled) and scored by an MLP that
    outputs one collision logit per pair. When ``line_segment_B`` is available,
    the segment is encoded as an extra feature of B's patch latent.

    NOTE(review): Flax names submodules by creation order, so the order of the
    nn.Dense calls below fixes the parameter tree of saved checkpoints.
    '''
    # Keep-probability for the training-time logit dropout: a logit survives where a
    # uniform sample is <= dropout, otherwise it becomes -1e8. Values >= 1.0 disable it.
    dropout:float
    # Rotation configuration forwarded to broad_phase and rmutil.apply_rot.
    rot_configs:Sequence
    # Width multiplier of the final pairwise MLP (net_dim = z_feat_dim * multihead_no).
    multihead_no:int=4
    # NOTE(review): not read in __call__.
    version:int=0
    # Number of Dense layers in the final MLP; the auxiliary MLPs use depth-1 layers.
    depth:int=3
    # NOTE(review): presumably meant to shrink z_feat_dim, but the guard in __call__ only
    # divides when this equals 1 (a no-op). Any other value leaves z_feat_dim unchanged and
    # only adds a Dense projection back to the original width. Fixing the guard changes
    # parameter shapes, so check existing checkpoints first.
    feat_size_divider:int=1
    # Divides the width of the positional-feature MLP.
    pos_emb_size_divider:int=1
    # NOTE(review): not read in __call__ (only referenced by commented-out code).
    normalize_eps:float=0.01
    
    @nn.compact
    def __call__(self, latent_obj_A:loutil.LatentObjects, latent_obj_B:loutil.LatentObjects, 
                 pq_transform_B=None, line_segment_B=None, reduce_k=16, merge=False, jkey=None, 
                 path_check=False, broadphase_type='timeoptbf_traj', broadphase_func=None, 
                 pairwise_out=False, debug=False, train=False):
        '''
        Score collisions between the FPS patches of latent_obj_A and latent_obj_B.

        latent_obj_AB -> (nob, 2, ) or (2, )
        pq_transform -> (nob, NP, 2, 7) or (NP, 2, 7, ) or None
        return
            (nob 1) or (nob NP 1)
        Dimension of pq_transform
        1) latent_obj (NS, ...) pq (NP, NS, 7) -> (NP, NS)
        Dimensions of merging is determined by merge pq dims
        2) merge=True latent_obj (NS, ...) pq (NP, NS, 7) -> (NP,) (NS dim is merged)
        3) merge=True latent_obj (NS, ...) pq (NB, NP, NS, 7) -> (NB,) (NP, NS dim is merged)

        Args:
            pq_transform_B, reduce_k: forwarded to the broad-phase routine.
            line_segment_B: forwarded to broad_phase.reduce_fps in the non-path branch.
                Every broad-phase routine returns a (possibly new) line_segment_B that is
                used from then on.
            merge, debug: forwarded to broad_phase.reduce_fps (non-path branch only).
            jkey: PRNG key for the broad phase; also split for dropout when training.
            path_check: if True, use the path broad phase selected by broadphase_type.
            broadphase_type: only the part before the first '_' is used. Recognised
                values: 'timeopt', 'timeoptbf', 'naive', 'naivebf', 'aabb'. A prefix
                ending in 'bf' discards broadphase_func.
            broadphase_func: forwarded to the timeopt / naive path routines.
            pairwise_out: if True, return the per-pair logits.
            train: enables logit dropout; also forwarded to broad_phase.reduce_fps.

        Returns:
            If path_check or pairwise_out: per-pair logits (..., nA_reduced, nB_reduced, 1),
            with A on axis -3 and B on axis -2.
            Otherwise (col_logit_global, patchwise_col_logits_A, patchwise_col_logits_B):
            the max logit over all pairs, plus per-patch max logits scattered back to the
            original nfps of A and B. Patches not selected by the broad phase get -1.

        NOTE(review): an unrecognised path broadphase type leaves the broad-phase outputs
        unbound and fails with UnboundLocalError. The 'aabb' branch defines no Aidx/Bidx;
        that is only safe because path_check always returns the pairwise logits early.
        '''
        nfpsA_original = latent_obj_A.nfps
        nfpsB_original = latent_obj_B.nfps

        def idx_update(x_in, nfps, idx):
            # Scatter reduced per-patch logits x_in (..., n_sel, 1) into a (..., nfps, 1)
            # array at positions idx. Positions not listed in idx keep the fill value -1.
            patchwise_col_logits = -jnp.ones(x_in.shape[:-2] + (nfps,1))
            origin_shape = patchwise_col_logits.shape
            patchwise_col_logits_flat = patchwise_col_logits.reshape(-1, nfps, 1)
            x_in_flat = x_in.reshape(-1, *x_in.shape[-2:])
            idx_flat = idx.reshape(-1, idx.shape[-1])
            out_logit = jax.vmap(lambda x, idx, y: y.at[idx].set(x))(x_in_flat, idx_flat, patchwise_col_logits_flat)
            return out_logit.reshape(origin_shape)
        
        if path_check:
            # only the prefix before the first '_' selects the path broad phase
            path_broadphase_type = broadphase_type.split("_")[0]
            # '...bf' variants run without broadphase_func
            if path_broadphase_type[-2:] == 'bf':
                broadphase_func = None
            if path_broadphase_type in ['timeopt', 'timeoptbf']:
                latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, \
                    t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B = broad_phase.reduce_fps_path_zoom(latent_obj_A, latent_obj_B, 
                                                                    pq_transform_B, reduce_k, self.rot_configs, jkey, broadphase_func, visualize=False)
                # if debug:
                #     return line_segment_B
            elif path_broadphase_type in ['naive', 'naivebf']:
                latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, \
                    t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B = broad_phase.reduce_fps_path_pair(latent_obj_A, latent_obj_B, 
                                                                    pq_transform_B, reduce_k, self.rot_configs, broadphase_func=broadphase_func, visualize=False)
            elif path_broadphase_type == 'aabb':
                # NOTE: no Aidx/Bidx here; fine only because path_check returns early below
                latent_obj_A_reduced, latent_obj_B_reduced, t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B = broad_phase.reduce_fps_path_aabb(latent_obj_A, latent_obj_B,
                                                                    pq_transform_B, reduce_k, self.rot_configs)
        else:
            latent_obj_A_reduced, latent_obj_B_reduced, Aidx, Bidx, t_clamped_pairwise, pairwise_line_AB_reduced, line_segment_B = broad_phase.reduce_fps(latent_obj_A, latent_obj_B, 
                                                                    line_segment_B=line_segment_B, jkey=jkey, pq_transform_B=pq_transform_B, reduce_k=reduce_k, 
                                                                    rot_configs=self.rot_configs, debug=debug, merge=merge, train=train)
        
        # per-pair offset vectors from the broad phase (vector on the last axis)
        pairwise_pos_dif = pairwise_line_AB_reduced 
        pairwise_outer_shape = pairwise_pos_dif.shape[:-1]

        # only referenced by the commented-out visualisation call further down
        if debug:
            line_segment_B_original = line_segment_B

        z_feat_dim = latent_obj_A_reduced.z_flat.shape[-1]
        z_norm_A = jnp.linalg.norm(latent_obj_A_reduced.z_flat, axis=-1)
        z_norm_B = jnp.linalg.norm(latent_obj_B_reduced.z_flat, axis=-1)
        if z_norm_B.shape != pairwise_outer_shape:
            # add pairwise dimensions
            # (insert the A axis at -2 so B broadcasts against every A patch)
            z_norm_B = z_norm_B[...,None,:]
            z_B = latent_obj_B_reduced.z[...,None,:,:,:]
            if line_segment_B is not None:
                line_segment_B = line_segment_B[...,None,:,:]
        else:
            z_B = latent_obj_B_reduced.z
        z_norm = jnp.maximum(z_norm_A[...,None], z_norm_B)
        # scale_eps = self.normalize_eps
        # scale = jnp.maximum(z_norm, scale_eps)
        # scale = max(z_norm, 2*|offset|): after dividing by it both latent norms are <= 1
        # and the pair offset is <= 0.5. It is treated as a constant for gradients.
        pairwise_dist = jnp.linalg.norm(pairwise_pos_dif, axis=-1)
        scale = jnp.where(z_norm > 2*pairwise_dist, z_norm, 2*pairwise_dist)
        scale = jax.lax.stop_gradient(scale)

        # pair latents concatenated on the last axis: [z_A | z_B]
        pairwise_z_shape = pairwise_outer_shape + latent_obj_A_reduced.z.shape[-2:]
        pairwise_z = jnp.concat([jnp.broadcast_to(latent_obj_A_reduced.z[...,None,:,:], pairwise_z_shape),
                                    jnp.broadcast_to(z_B, pairwise_z_shape)], axis=-1)
        
        # swap z and pos_dif A always larger than B
        # (canonical order: the larger-norm latent goes in the first slot; offset negated)
        swap_mask = z_norm_A[...,None] < z_norm_B
        pairwise_pos_dif = jnp.where(swap_mask[...,None], -pairwise_pos_dif, pairwise_pos_dif)
        pairwise_z = jnp.where(swap_mask[...,None,None], jnp.concatenate([pairwise_z[...,latent_obj_A_reduced.z.shape[-1]:], pairwise_z[...,:latent_obj_A_reduced.z.shape[-1]]], axis=-1), pairwise_z)

        if line_segment_B is not None:
            if (line_segment_B.ndim != z_B.ndim-1):
                line_segment_B = line_segment_B[...,None,:]
            line_segment_B = jnp.where(swap_mask[...,None], -line_segment_B, line_segment_B)

        # Rotate latents and offset into the frame from tutil.line2Rm(offset), then rescale.
        # Only component 2 of the rotated offset is used later, so the offset presumably
        # maps onto the z-axis (confirm in tutil.line2Rm).
        line_align_Rm = tutil.line2Rm(pairwise_pos_dif)
        line_align_Rm_inv = tutil.Rm_inv(line_align_Rm)
        pairwise_z = rmutil.apply_rot(pairwise_z, line_align_Rm_inv, self.rot_configs, feature_axis=-2)
        pairwise_pos_dif = jnp.einsum('...ij,...j', line_align_Rm_inv, pairwise_pos_dif)
        pairwise_z = pairwise_z / scale[...,None,None]
        pairwise_pos_dif = pairwise_pos_dif / scale[...,None]

        # add line segment feature
        if line_segment_B is not None:
            line_segment_B = jnp.einsum('...qij,...qj->...qi', line_align_Rm_inv, line_segment_B)
            line_segment_B_norm = jnp.linalg.norm(line_segment_B, axis=-1)
            line_segment_B_normalized = tutil.normalize(line_segment_B)
            # fold t into [0.5, 1], flipping the direction when t < 0.5
            line_segment_B_normalized = jnp.where(t_clamped_pairwise[...,None]>0.5, line_segment_B_normalized, -line_segment_B_normalized)
            t_clamped_pairwise = jnp.where(t_clamped_pairwise>0.5, t_clamped_pairwise, 1-t_clamped_pairwise)
            # remaining segment length on each side of t, capped at z_norm and rescaled
            fwd_magnitude = line_segment_B_norm * (1-t_clamped_pairwise)
            bwd_magnitude = line_segment_B_norm * t_clamped_pairwise
            fwd_magnitude = fwd_magnitude.clip(0, z_norm)/scale
            bwd_magnitude = bwd_magnitude.clip(0, z_norm)/scale
            
            # 6-vector: [forward displacement, backward displacement]
            line_segment_dir = jnp.c_[line_segment_B_normalized * fwd_magnitude[...,None], -line_segment_B_normalized * bwd_magnitude[...,None]]
        else:
            # create dummy
            # (presumably so the segment MLP's parameters are created in both modes)
            line_segment_dir = jnp.zeros((6,))
        
        original_z_feat_dim = z_feat_dim
        # NOTE(review): dividing by feat_size_divider only when it equals 1 is a no-op;
        # the condition looks inverted. See the field comment before changing it.
        if self.feat_size_divider == 1:
            z_feat_dim = z_feat_dim//self.feat_size_divider
        
        # segment-feature MLP (depth-1 layers); also runs on the dummy input
        line_seg_feat = line_segment_dir
        for i in range(self.depth-1):
            line_seg_feat = nn.Dense(z_feat_dim)(line_seg_feat)
            line_seg_feat = nn.gelu(line_seg_feat)

        pairwise_z_A, pairwise_z_B = pairwise_z[...,:latent_obj_A_reduced.z.shape[-1]], pairwise_z[...,latent_obj_A_reduced.z.shape[-1]:]
        if line_segment_B is not None:
            pairwise_z_B_feat = jnp.concat([einops.rearrange(pairwise_z_B, '... j k -> ... (j k)'), line_seg_feat], axis=-1)
        else:
            # dummy input for the B-feature MLP; its output is discarded below.
            # NOTE(review): this width matches the real input only when depth > 1.
            pairwise_z_B_feat = jnp.zeros((original_z_feat_dim+z_feat_dim,))

        for i in range(self.depth-1):
            pairwise_z_B_feat = nn.Dense(z_feat_dim)(pairwise_z_B_feat)
            pairwise_z_B_feat = nn.gelu(pairwise_z_B_feat)
        
        # project back to the latent width when the feature width was changed
        if self.feat_size_divider != 1:
            pairwise_z_B_feat = nn.Dense(original_z_feat_dim)(pairwise_z_B_feat)

        # Replace B's latent by the segment-conditioned feature, except where the
        # segment has (near) zero length.
        if line_segment_B is not None:
            pairwise_z_B_original = pairwise_z_B
            pairwise_z_B = pairwise_z_B_feat
            pairwise_z_B = einops.rearrange(pairwise_z_B, '... (j k) -> ... j k', j=latent_obj_B_reduced.z.shape[-2])
            pairwise_z_B = jnp.where(line_segment_B_norm[...,None, None] < 1e-6, pairwise_z_B_original, pairwise_z_B)

        pairwise_z = jnp.concatenate([pairwise_z_A, pairwise_z_B], axis=-1)

        # positional input: only the component(s) from index 2 of the aligned, rescaled offset
        pairwise_pos_feat = pairwise_pos_dif[...,2:]

        # num_encoding_dims = 16
        # emb = jnp.power(2, jnp.arange(0, num_encoding_dims)-4)
        # pairwise_pos_feat = pairwise_pos_feat[..., None, :] * emb[..., :, None]
        # pairwise_pos_feat = jnp.concat([jnp.sin(pairwise_pos_feat), jnp.cos(pairwise_pos_feat)], axis=-2)
        # pairwise_pos_feat = pairwise_pos_feat.reshape(*pairwise_pos_feat.shape[:-2], -1)

        for i in range(self.depth-1):
            pairwise_pos_feat = nn.Dense(z_feat_dim//self.pos_emb_size_divider)(pairwise_pos_feat)
            pairwise_pos_feat = nn.gelu(pairwise_pos_feat)
        
        pairwise_z_flat = einops.rearrange(pairwise_z, '... j k -> ... (j k)')
        pairwise_feat = jnp.concat([pairwise_z_flat, pairwise_pos_feat], axis=-1)

        # process pairwise features with resnet
        # (skip connection from the first layer; NOTE(review): `skip` is undefined if depth == 0)
        net_dim = z_feat_dim * self.multihead_no
        for i in range(self.depth):
            pairwise_feat = nn.Dense(net_dim)(pairwise_feat)
            pairwise_feat = nn.gelu(pairwise_feat)
            if i==0:
                skip = pairwise_feat
        pairwise_feat += skip
        pairwise_feat = nn.Dense(1)(pairwise_feat)
        # training-time dropout on the per-pair logits: dropped pairs become -1e8
        if train and self.dropout < 1.0:
            jkey, subkey = jax.random.split(jkey)
            pairwise_feat = jnp.where(jax.random.uniform(subkey, shape=pairwise_feat.shape) <= self.dropout, pairwise_feat, -1e8)

        # if debug:
        #     visualize_collision(latent_obj_A_reduced, latent_obj_B_reduced, line_segment_B_original, t_clamped_pairwise, latent_obj_A, latent_obj_B, pq_transform_B, pairwise_feat)

        if path_check or pairwise_out:
            return pairwise_feat

        # max over all pairs, and per-patch max over the other object's patches
        col_logit_global = jnp.max(pairwise_feat, axis=(-2,-3))
        col_logit_local_A = jnp.max(pairwise_feat, axis=-2)
        col_logit_local_B = jnp.max(pairwise_feat, axis=-3)

        patchwise_col_logits_A = idx_update(col_logit_local_A, nfpsA_original, Aidx)
        patchwise_col_logits_B = idx_update(col_logit_local_B, nfpsB_original, Bidx)

        return col_logit_global, patchwise_col_logits_A, patchwise_col_logits_B



class ColDecoderGlobal(nn.Module):
    '''Collision logit for a pair of objects from their global latents.

    The pair is centred on its mean position, rescaled by the mean object
    ``len``, and rotated into a frame built from the object centres and the
    farthest FPS points. An MLP then embeds each object, max-pools over the two
    objects, and outputs one logit.
    '''
    base_dim:int                  # hidden width; the main MLP uses 4 * base_dim
    droptout:float                # NOTE(review): not used in __call__; name kept for constructor compatibility
    rot_configs:typing.Sequence   # rotation configuration passed to apply_pq_z

    @nn.compact
    def __call__(self, obj_A:loutil.LatentObjects, obj_B:loutil.LatentObjects, 
                 pq_transform_B=None, line_segment_B=None,
                 jkey=None, reduce_k=None, train=False):
        '''
        obj_A - outer_shape: (B, )
        obj_B - outer_shape: (B, )
        pq_transform_B, line_segment_B, jkey, reduce_k, train: accepted for interface
            compatibility; not used.
        return: (B, 1) collision logits (no sigmoid applied).
        '''
        hidden_dim = self.base_dim * 4
        origin = jnp.zeros((3,), jnp.float32)

        # stack A and B on a new pair axis and centre the pair on its mean position
        pair = obj_A.stack(obj_B, axis=-1)
        pair_c = pair.translate(-jnp.mean(pair.pos, axis=-2, keepdims=True))
        # normalise by the mean object length (taken before centring)
        pair_c = pair_c.apply_scale(1/jnp.mean(pair.len, axis=-1, keepdims=True), center=origin)

        # farthest FPS point of each object and which of the two reaches farther
        fps_norm = jnp.linalg.norm(pair_c.fps_tf, axis=-1)
        farthest = jnp.take_along_axis(pair_c.fps_tf, jnp.argmax(fps_norm, axis=-1)[...,None,None], axis=-2)
        farthest_norm = jnp.linalg.norm(farthest, axis=-1)
        a_reaches_farther = farthest_norm[...,0:1,:] > farthest_norm[...,1:2,:]

        # frame axes: centre of the farther-reaching object, and its farthest FPS point
        z_axis = jnp.where(a_reaches_farther, pair_c.pos[...,0:1,:], pair_c.pos[...,1:2,:])
        y_axis = jnp.where(a_reaches_farther, farthest[...,0,:,:], farthest[...,1,:,:])
        # NOTE(review): fallback triggers when y.z is ~0 (orthogonal or zero y), not when
        # y is parallel to z; confirm which case tutil.line2q needs guarded.
        yz_dot_small = jnp.abs(jnp.sum(y_axis*z_axis, -1, keepdims=True)) < 1e-4
        y_axis = jnp.where(yz_dot_small, jnp.sum(pair_c.z[...,0,:3,0], axis=-2, keepdims=True), y_axis)

        # rotate the pair into that frame
        q_align = tutil.line2q(z_axis, yaxis=y_axis)
        pair_c = pair_c.apply_pq_z(origin, tutil.qinv(q_align), self.rot_configs)

        # positional embedding from the z-coordinate only
        pos_emb = pair_c.pos[...,-1:]
        for _ in range(2):
            pos_emb = jnp.sin(nn.Dense(self.base_dim//2)(pos_emb))

        # aligned z is rot-invariant -> flatten it directly instead of an invariant layer
        z_flat = jnp.squeeze(einops.rearrange(pair_c.z, '... f d -> ... (f d)'), axis=-2)

        x = jnp.concatenate([z_flat, pos_emb], axis=-1)
        x = nn.relu(nn.Dense(hidden_dim)(x))
        skip = x
        for _ in range(2):
            x = nn.relu(nn.Dense(hidden_dim)(x))
        x = nn.relu(nn.Dense(hidden_dim)(x + skip))

        # max-pool over the two objects, then score
        x = jnp.max(x, axis=-2)
        x = nn.relu(nn.Dense(x.shape[-1])(x))
        return nn.Dense(1)(x)
    


def aggregate_cost(cost, axes, reduce_ops=jnp.sum):
    '''
    Reduce collision costs over ``axes`` (positive cost = collision).

    If any entry along ``axes`` is positive, the result is ``reduce_ops``
    applied to the positive part of ``cost``; non-positive entries contribute
    0. Otherwise the result is the maximum entry, i.e. the least negative one.

    cost: array of costs.
    axes: axis or tuple of axes to reduce over.
    reduce_ops: reduction for the positive part (default jnp.sum); called as
        ``reduce_ops(x, axis=axes)``.
    '''
    worst = jnp.max(cost, axis=axes)
    violation = reduce_ops(jnp.maximum(cost, 0), axis=axes)
    return jnp.where(worst > 0, violation, worst)


def pairwise_collision_check(models:Models, oriCORNs_A:loutil.LatentObjects, oriCORNs_B:loutil.LatentObjects, posquat_A, posquat_B, aggregate=True):
    '''
    Collision costs between every object of A and every object of B.

    oriCORNs_A: (S..., NA) objects; oriCORNs_B: (S..., NB) objects. The leading
        scene dims S (all but the object axis) are broadcast together.
    posquat_A: (S..., Q..., NA, 7) poses of A; a 2-D (NA, 7) input becomes one query.
    posquat_B: (S..., Q..., NB, 7) poses of B; same convention. Any pose dims
        beyond S are independent queries Q.
    aggregate: if True, reduce each state's (NA, NB) costs with aggregate_cost.

    return
        aggregate=True : (S..., Q...) per-state cost
        aggregate=False: (S..., Q..., NA, NB) per-pair cost (cost > 0 means collision)
    '''
    # broad cast outer shape
    if posquat_A.ndim == 2:
        posquat_A = posquat_A[None]
    if posquat_B.ndim == 2:
        posquat_B = posquat_B[None]

    oriCORN_scene_batch_shape = jnp.broadcast_shapes(oriCORNs_A.shape[:-1], oriCORNs_B.shape[:-1])
    posquat_scene_batch_shape = jnp.broadcast_shapes(posquat_A.shape[:-2], posquat_B.shape[:-2])
    n_scene_dims = len(oriCORN_scene_batch_shape)

    # pose dims beyond the scene dims are independent queries
    posquat_query_shape = posquat_scene_batch_shape[n_scene_dims:]

    oriCORNs_A = oriCORNs_A.broadcast_outershape(oriCORN_scene_batch_shape + (oriCORNs_A.nobj,))
    oriCORNs_B = oriCORNs_B.broadcast_outershape(oriCORN_scene_batch_shape + (oriCORNs_B.nobj,))

    posquat_A = jnp.broadcast_to(posquat_A, posquat_scene_batch_shape + posquat_A.shape[-2:])
    posquat_B = jnp.broadcast_to(posquat_B, posquat_scene_batch_shape + posquat_B.shape[-2:])

    # build all (A, B) object pairs, flattened to (prod(S)*NA*NB, 2)
    num_A = oriCORNs_A.nobj
    num_B = oriCORNs_B.nobj
    oriCORN_A_pw = oriCORNs_A.extend_and_repeat_outer_shape(num_B, -1)
    oriCORN_B_pw = oriCORNs_B.extend_and_repeat_outer_shape(num_A, -2)
    col_query_oriCORN_pw = oriCORN_A_pw.stack(oriCORN_B_pw, axis=-1)
    col_query_oriCORN_pw = col_query_oriCORN_pw.reshape_outer_shape((-1, 2))

    robot_pqc_pw = posquat_A[...,None,:].repeat(num_B, axis=-2)  # (S..., Q..., NA, NB, 7)
    scene_pqc_pw = posquat_B[...,None,:,:].repeat(num_A, axis=-3)  # (S..., Q..., NA, NB, 7)
    pqc_pw = jnp.stack([robot_pqc_pw, scene_pqc_pw], axis=-2)  # (S..., Q..., NA, NB, 2, 7)

    # int(): np.prod(()) is the float 1.0, which reshape rejects
    nquery = int(np.prod(posquat_query_shape))
    # the decoder expects poses as (NP, NS, 7): move the query axis first
    pqc_pw = pqc_pw.reshape(-1, nquery, num_A*num_B, *pqc_pw.shape[-2:])
    pqc_pw = jnp.moveaxis(pqc_pw, 0, 1)  # (nquery, prod(S), NA*NB, 2, 7)
    pqc_pw = pqc_pw.reshape(nquery, -1, *pqc_pw.shape[-2:])  # (nquery, prod(S)*NA*NB, 2, 7)

    col_cost = models.apply('col_decoder', col_query_oriCORN_pw[:,0], col_query_oriCORN_pw[:,1], pqc_pw[...,0,:], pqc_pw[...,1,:])[0]

    # The decoder output is query-major (nquery, prod(S)*NA*NB, ...). Undo the
    # moveaxis above instead of reshaping straight to (S..., Q..., NA, NB), which
    # would mix scene and query entries when both are present.
    n_query_dims = len(posquat_query_shape)
    scene_shape = posquat_scene_batch_shape[:n_scene_dims]
    col_cost = col_cost.reshape(*posquat_query_shape, *scene_shape, num_A, num_B)
    if n_query_dims and n_scene_dims:
        col_cost = jnp.moveaxis(col_cost, tuple(range(n_query_dims)),
                                tuple(range(n_scene_dims, n_scene_dims + n_query_dims)))
    # col_cost: (S..., Q..., NA, NB); col_cost > 0 marks a colliding pair

    if aggregate:
        # more clever way to aggregate this values?
        state_col_cost = aggregate_cost(col_cost, (-2,-1))
        return state_col_cost
    else:
        return col_cost
    


def sdf_collision_check(models:Models, oriCORNs_A:loutil.LatentObjects, oriCORNs_B:loutil.LatentObjects):
    '''
    SDF-based collision cost between paired objects of A and B.

    1) evaluate each object's SDF decoder at the other object's FPS points
    2) any point with SDF < 0 (inside the other object) indicates collision

    The outer shapes of A and B are broadcast together and the result is
    reshaped to that broadcast shape. The cost comes from aggregate_cost on
    -SDF over axis -2. If any point is inside, it is the sum of the
    penetrations; otherwise it is -min(SDF) <= 0. Assumes the decoder output
    has the query points on axis -2 (TODO confirm against the sdf_decoder).
    '''
    # broad cast outer shape
    oriCORN_scene_batch_shape = jnp.broadcast_shapes(oriCORNs_A.pos.shape[:-1], oriCORNs_B.pos.shape[:-1])
    # Broadcast before flattening so both sides flatten to the same number of
    # objects; flattening unequal (but broadcastable) shapes would misalign the pairs.
    if oriCORNs_A.pos.shape[:-1] != oriCORN_scene_batch_shape:
        oriCORNs_A = oriCORNs_A.broadcast_outershape(oriCORN_scene_batch_shape)
    if oriCORNs_B.pos.shape[:-1] != oriCORN_scene_batch_shape:
        oriCORNs_B = oriCORNs_B.broadcast_outershape(oriCORN_scene_batch_shape)
    oriCORNs_A = oriCORNs_A.reshape_outer_shape((-1,))
    oriCORNs_B = oriCORNs_B.reshape_outer_shape((-1,))
    fps_tf_A = oriCORNs_A.fps_tf
    fps_tf_B = oriCORNs_B.fps_tf

    # SDF of A at B's FPS points, and SDF of B at A's FPS points
    B_sdf = models.apply('sdf_decoder', oriCORNs_A, fps_tf_B)
    A_sdf = models.apply('sdf_decoder', oriCORNs_B, fps_tf_A)

    state_col_cost = aggregate_cost(-jnp.concatenate([B_sdf, A_sdf], axis=-2), (-2,))
    state_col_cost = state_col_cost.reshape(oriCORN_scene_batch_shape)
    return state_col_cost
