import jax.numpy as jnp
from pathlib import Path
import jax
import jax.debug as jdb
try:
    # NOTE: the optional `ott` imports are currently commented out, so nothing
    # in this block can raise and `ott_on` is always True.
    # from ott.geometry import pointcloud, costs
    # from ott.solvers import linear
    ott_on = True
except ImportError:
    # Only a missing optional dependency should switch the flag off; a bare
    # `except:` would also swallow KeyboardInterrupt / SystemExit.
    ott_on = False
    print("ott not found")
import typing
import optax
from functools import partial
import sys
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import einops
import flax.linen as nn
import time
import wandb
import os

# Directory two levels above this file. It is put at the front of sys.path
# (once) -- presumably so that the `util.*` imports below resolve when this
# file is run directly; confirm against the repository layout.
BASEDIR = Path(__file__).parent.parent
if str(BASEDIR) not in sys.path:
    sys.path.insert(0, str(BASEDIR))


import util.diffusion_util as dfutil
import util.render_util as rutil
import util.model_util as mutil
import util.transform_util as tutil
import util.camera_util as cutil
import util.structs as structs
import util.train_util as trutil
import util.latent_obj_util as loutil
import util.bp_matching_util as bputil

def debug_callback(inputs):
    """Print ``inputs`` followed by the sums of its first and second elements."""
    first_total = np.sum(inputs[0])
    second_total = np.sum(inputs[1])
    print(inputs, first_total, second_total)

def update_tree(new_tree, old_tree):
    """Recursively copy leaves from ``old_tree`` into ``new_tree`` where they line up.

    The structure of ``new_tree`` is always kept:

    * dict / dict: keys of ``new_tree`` are updated recursively when the same
      key exists in ``old_tree``; keys only present in ``old_tree`` are dropped.
    * list-or-tuple / list-or-tuple: updated element-wise when both have the
      same length, otherwise ``new_tree`` is returned untouched.
    * anything else is a leaf: the ``old_tree`` leaf is used when both leaves
      expose a ``shape`` attribute and the shapes are equal; otherwise the
      ``new_tree`` leaf is kept.

    Args:
        new_tree: Tree that defines the output structure and default leaves.
        old_tree: Tree whose leaves are reused where names and shapes match.

    Returns:
        A tree with the structure of ``new_tree``.
    """
    if isinstance(new_tree, dict) and isinstance(old_tree, dict):
        return {
            key: update_tree(new_val, old_tree[key]) if key in old_tree else new_val
            for key, new_val in new_tree.items()
        }

    if isinstance(new_tree, (list, tuple)) and isinstance(old_tree, (list, tuple)):
        if len(new_tree) != len(old_tree):
            return new_tree
        updated = [update_tree(n, o) for n, o in zip(new_tree, old_tree)]
        # Lists and plain tuples are built from a single iterable. The previous
        # code called ``list(*updated)`` for lists, which raises TypeError for
        # two or more elements.
        if isinstance(new_tree, list) or type(new_tree) is tuple:
            return type(new_tree)(updated)
        # Remaining tuple subclasses (e.g. namedtuples) take one positional
        # argument per field.
        return type(new_tree)(*updated)

    # Leaf: reuse the old value only when both sides report the same shape.
    try:
        if new_tree.shape == old_tree.shape:
            return old_tree
    except AttributeError:
        # At least one side has no ``shape``; keep the new leaf.
        pass
    return new_tree



def visualize_dataset(train_loader, models:mutil.Models, args):
    """Show the first batch of ``train_loader`` next to a render of its ground-truth objects.

    For the first sample of the first batch, up to five input views are drawn
    in a 3x3 matplotlib grid, followed by the RGB and segmentation images that
    ``rutil.cvx_render_scene`` produces for the valid ground-truth objects from
    the first camera. The figure is shown with ``plt.show()``. Only the first
    batch is consumed; an empty loader does nothing.

    Args:
        train_loader: Iterable of batches with keys ``'rgbs'``, ``'obj_info'``
            and ``'cam_info'`` (``'cam_intrinsics'`` / ``'cam_posquats'``).
        models: Model container passed to the SDF and render utilities.
        args: Unused; kept so existing callers keep working.
    """
    for ds_batch in train_loader:
        rgbs = ds_batch['rgbs']

        gt_obj = loutil.LatentObjects().init_obj_info(ds_batch['obj_info'], models.mesh_aligned_canonical_obj, models.rot_configs)

        vis_idx = 0
        intrinsic = ds_batch["cam_info"]['cam_intrinsics'].astype(np.float32)
        cam_posquat = ds_batch["cam_info"]['cam_posquats'].astype(np.float32)
        vis_intrinsic = intrinsic[vis_idx,0]
        vis_cam_posquat = cam_posquat[vis_idx,0]
        # The first two intrinsic entries are swapped to build the pixel size
        # handed to the renderer.
        pixel_size = (int(vis_intrinsic[1]), int(vis_intrinsic[0]))
        vis_cam_pos, vis_cam_quat = vis_cam_posquat[:3], vis_cam_posquat[3:]
        sdf_func = partial(rutil.scene_sdf, models=models)

        # A single jitted renderer is enough: it is called exactly once here.
        # The previous per-object-count lookup table raised KeyError whenever
        # the number of valid objects was not one of its pre-registered keys.
        render_scene = jax.jit(partial(rutil.cvx_render_scene, models=models, sdf=sdf_func, pixel_size=pixel_size,
                                       intrinsic=vis_intrinsic, camera_pos=vis_cam_pos, camera_quat=vis_cam_quat, seg_out=True))

        gt_obj_sq = gt_obj[vis_idx]
        gt_obj_sq = gt_obj_sq[gt_obj_sq.obj_valid_mask]

        rgb_pred, seg_pred = render_scene(gt_obj_sq)

        plt.figure()
        # Panels 1-5: input views (fewer when the batch has fewer than 5 views).
        n_views = min(5, rgbs.shape[1])
        for i in range(n_views):
            plt.subplot(3,3,i+1)
            plt.imshow(rgbs[vis_idx,i])
            plt.axis('off')
        # Panel 6: rendered RGB, panel 7: rendered segmentation.
        plt.subplot(3,3,6)
        plt.imshow(rgb_pred)
        plt.axis('off')
        plt.subplot(3,3,7)
        plt.imshow(seg_pred)
        plt.axis('off')
        plt.show()

        # Only the first batch is visualized.
        break


# @partial(jax.jit, static_argnums=(0,))
def model_apply_param_jit(models, params, x, cond, t, cond_mask, jk, previous_emb):
    """Apply the denoiser using ``params`` and a positionally-encoded condition.

    ``params`` are bound through ``models.set_params(..., update_trainable_only=True)``,
    ``cond`` is passed through the ``'spatial_PE'`` module with ``train=False``,
    and the result of the ``'denoiser'`` module is returned. ``jk`` is supplied
    as the ``'dropout'`` rng.
    """
    bound_models = models.set_params(params, update_trainable_only=True)
    encoded_cond = bound_models.apply('spatial_PE', cond, train=False)
    dropout_rngs = {'dropout': jk}
    return bound_models.apply(
        'denoiser', x, encoded_cond, t, cond_mask,
        previous_emb=previous_emb, rngs=dropout_rngs,
    )

   
def eval_dif_model(models:mutil.Models, eval_batch:dict, params, jkey, itr, logs_dir, tb_writer, args, wandb_run):
    """Sample from the diffusion model on one evaluation sample, log visuals and return metrics.

    Only the first element of ``eval_batch`` is used. Unless ``args.single_ds``
    is set, the ground-truth objects and the cameras are moved by the same
    random translation (uniform in [-0.1, 0.1]) and z-rotation (uniform in
    [-pi, pi]). The model is then sampled with 5/10/25/50 Euler steps (timed),
    and ground-truth objects are perturbed and recovered at t = 0.1/0.5/1.0.
    Every ``2 * args.save_interval`` iterations the diffusion / prediction
    trajectories of the last sampling run are written as mp4 files under
    ``<logs_dir>/video``, sent to ``tb_writer`` and, when ``args.debug`` is
    false, to ``wandb_run``.

    Args:
        models: Model container; ``params`` are bound with ``set_params``.
        eval_batch: Batch with ``'obj_info'``, ``'cam_info'``, ``'rgbs'`` (may
            be ``None``) and, depending on ``args``, ``'seg'`` / ``'env_info'``.
        params: Parameters to evaluate.
        jkey: JAX PRNG key.
        itr: Current training iteration (controls video logging and log steps).
        logs_dir: Directory under which the ``video`` folder is created.
        tb_writer: Writer exposing ``add_video``.
        args: Run configuration (``single_ds``, ``nparticles``, ``fps_only``,
            ``train_seg``, ``train_seg_only``, ``save_interval``, ``dm_type``,
            ``debug``, ``render_loss_weight`` are read here).
        wandb_run: Object exposing ``log``; only used when ``args.debug`` is false.

    Returns:
        ``(image, metrics)``. ``image`` is the feature-renderer comparison when
        ``args.render_loss_weight != 0`` and images are available, otherwise
        the matplotlib overview figure as an array. In ``train_seg_only`` mode
        the segmentation overlay and ``{'eval_seg_iou': ...}`` are returned.

    Raises:
        ValueError: In ``train_seg_only`` mode when no segmentation prediction
            could be computed (no images, or ``args.train_seg`` disabled).
    """
    eval_batch = jax.tree_util.tree_map(lambda x: x[:1], eval_batch)
    _, jkey = jax.random.split(jkey)
    models_ = models.set_params(params, update_trainable_only=True)
    models = models_

    obj_info = eval_batch["obj_info"]
    ns = eval_batch["cam_info"]['cam_intrinsics'].shape[0]

    gt_obj = loutil.LatentObjects().init_obj_info(obj_info, models.mesh_aligned_canonical_obj, models.rot_configs)
    if not args.single_ds:
        # Random rigid offset, applied to the objects here and to the cameras below.
        pos_randomization = jax.random.uniform(jkey, (ns, 1, 3), minval=-0.1, maxval=0.1)
        _, jkey = jax.random.split(jkey)
        z_rotation = jax.random.uniform(jkey, (ns, 1, 1), minval=-np.pi, maxval=np.pi)
        _, jkey = jax.random.split(jkey)
        z_rotation = jnp.c_[jnp.zeros_like(z_rotation), jnp.zeros_like(z_rotation), z_rotation]
        z_rotation = tutil.aa2q(z_rotation)
        gt_obj = gt_obj.apply_pq_z(pos_randomization, z_rotation, models.rot_configs)

    nh = gt_obj.h.shape[-1]

    vis_idx = 0
    # NOTE(review): ``ns`` is the leading (batch) size, which is 1 after the
    # slicing above, so this always yields view 0 -- confirm whether the number
    # of views was intended here.
    vp_idx = np.random.randint(0, ns)

    intrinsic = eval_batch["cam_info"]['cam_intrinsics'].astype(np.float32)
    cam_posquat = eval_batch["cam_info"]['cam_posquats'].astype(np.float32)
    if not args.single_ds:
        cam_posquat = tutil.pq_multi(pos_randomization, z_rotation, cam_posquat[...,:3], cam_posquat[...,3:])
        cam_posquat = jnp.concatenate(cam_posquat, axis=-1)
    vis_intrinsic = intrinsic[vis_idx,vp_idx]
    vis_cam_posquat = cam_posquat[vis_idx,vp_idx]
    pixel_size = (int(vis_intrinsic[1]), int(vis_intrinsic[0]))
    vis_cam_pos, vis_cam_quat = vis_cam_posquat[:3], vis_cam_posquat[3:]
    sdf_func = partial(rutil.scene_sdf, models=models_)
    latent_render_func = lambda: jax.jit(partial(rutil.cvx_render_scene, models=models_, sdf=sdf_func, pixel_size=pixel_size, 
                                                    intrinsic=vis_intrinsic, camera_pos=vis_cam_pos, camera_quat=vis_cam_quat, seg_out=True))
    # One jitted renderer per possible number of valid objects.
    latent_render_func_dict = {}
    for i in range(1, args.nparticles+1):
        latent_render_func_dict[i] = latent_render_func()
    gt_obj_sq = gt_obj[vis_idx]

    def render_img(obj:loutil.LatentObjects):
        """Render ``obj`` from the visualization camera and mark its projected points in white."""
        if obj.nobj != 1:
            obj = obj[obj.obj_valid_mask]
        if obj.outer_shape[0] == 0:
            return np.zeros((*pixel_size, 3)), np.zeros(pixel_size)

        # In fps_only mode the scene render is skipped and only points are drawn.
        if args.fps_only:
            rgb_, seg_ = np.zeros((*pixel_size, 3)), np.zeros(pixel_size)
        else:
            rgb_, seg_ = latent_render_func_dict[obj.outer_shape[0]](obj)

        pixel_coord, out_ = cutil.global_pnts_to_pixel(vis_intrinsic, vis_cam_posquat, obj.fps_tf) # (... NR)

        # In fps_only mode each point is thickened with its 4-neighbourhood.
        if args.fps_only:
            pixel_coord = pixel_coord[...,None,:] + np.array([[[0,0],[0,1],[1,0],[0,-1],[-1,0]]])
        else:
            pixel_coord = pixel_coord[...,None,:] + np.array([[[0,0]]])

        # Clamp to the image; the checks after the clip are kept as extra guards.
        pixel_coord = np.array(pixel_coord).astype(np.int32).clip(0, vis_intrinsic[:2][::-1]-1).astype(np.int32)
        pixel_coord = np.where(np.isnan(pixel_coord), 0, pixel_coord)
        pixel_coord = np.where(pixel_coord<0, 0, pixel_coord)
        pixel_coord = np.where(pixel_coord>10000, 0, pixel_coord)

        pixel_coord = pixel_coord.reshape(pixel_coord.shape[0], -1, 2)

        rgb_ = np.array(rgb_)
        for i in range(pixel_coord.shape[0]):
            rgb_[pixel_coord[i,:,0], pixel_coord[i,:,1]] = np.ones(3)
        return rgb_, seg_

    # Defaults so that the metrics below are defined when there are no images
    # or segmentation is disabled (previously an UnboundLocalError).
    seg_iou = 0
    seg_pred_vis = None
    if eval_batch["rgbs"] is None:
        cond = None
    else:
        if args.single_ds:
            cam_intrinsic_resized = cutil.resize_intrinsic(intrinsic, eval_batch["rgbs"].shape[-3:-1], (4,4))
            cond = structs.ImgFeatures(intrinsic=cam_intrinsic_resized, cam_posquat=cam_posquat, 
                                       img_feat=jnp.zeros(eval_batch["rgbs"].shape[:-3] + (4,4,512,)), rgb=eval_batch["rgbs"])
        else:
            cond = models_.apply('img_encoder', eval_batch["rgbs"], cam_posquat, intrinsic, dino_feat_out=args.train_seg_only==1)

        if args.train_seg:
            seg_pred = models_.apply('seg_predictor', cond)
            seg_pred = nn.sigmoid(seg_pred)
            seg_label = eval_batch["seg"].astype(jnp.float32)
            # Soft IoU: sum(p*l) / sum(p + l - p*l).
            seg_iou = jnp.sum(seg_pred.squeeze(-1) * seg_label)/(jnp.sum(seg_pred.squeeze(-1) + seg_label - seg_pred.squeeze(-1) * seg_label) + 1e-6)
            seg_pred_vis = seg_pred[vis_idx, vp_idx]

    model_apply_jit = models_.get_model_apply_func(spatial_PE=True, jit=True)

    rgb_gt = eval_batch["rgbs"][vis_idx][vp_idx] if eval_batch["rgbs"] is not None else np.zeros((*pixel_size, 3))

    def rgb_mix(rgb, pred, seg=None):
        """Blend ``rgb`` (scaled by 1/255) with ``pred``; ``pred`` is zeroed where ``seg < 0``."""
        alpha = 0.65
        if seg is not None:
            pred_ = np.where(seg[...,None]>=0, pred, 0)
        else:
            pred_ = pred
        return (1-alpha)*rgb/255. + pred_*alpha

    if args.train_seg_only:
        if seg_pred_vis is None:
            raise ValueError("train_seg_only evaluation needs eval_batch['rgbs'] and args.train_seg to be set")
        fig_ = plt.figure()
        plt.imshow(rgb_mix(rgb_gt, seg_pred_vis))
        fig_.canvas.draw()
        rgb = np.array(fig_.canvas.renderer._renderer)
        plt.close()

        eval_metric_dict = {
            'eval_seg_iou': seg_iou,
        }

        return rgb, eval_metric_dict


    # sampling: the same key is used for every step count
    _, jkey = jax.random.split(jkey)
    dif_num_timesteps = [5, 10, 25, 50]
    x_inf_list = []
    time_list = []
    for max_time_steps in dif_num_timesteps:
        time_start = time.time()
        x_pred, aux_mask = dfutil.euler_sampler_obj_single_fori(models_, (ns, args.nparticles, nh), model_apply_jit, cond, jkey, conf_filter_out=False,
                                                max_time_steps=max_time_steps, update_cam_pq=False)
        x_pred = jax.block_until_ready(x_pred)
        time_end = time.time()
        x_inf_list.append(x_pred)
        time_list.append(time_end - time_start)

    # Trajectories of the last (50-step) run only.
    x_pred_list = aux_mask['x_pred_list']
    x_dif_list = aux_mask['x_diffusion_list']

    x_inf_list_vis, x_dif_list_vis, x_pred_list_vis = \
        jax.tree_util.tree_map(lambda x: x[vis_idx], (x_inf_list, x_dif_list, x_pred_list))

    if itr%(args.save_interval*2) == 0:
        fps = 15 if args.dm_type != 'regression' else 3
        video_logs_dir = os.path.join(logs_dir, 'video')
        os.makedirs(video_logs_dir, exist_ok=True)
        video_arr = []
        with rutil.VideoWriter(os.path.join(video_logs_dir, f'dif_{itr}.mp4'), fps=fps) as vid:
            for t in range(x_dif_list_vis.shape[0]):
                video_arr.append(rgb_mix(rgb_gt, *render_img(x_dif_list_vis[t])))
                vid(video_arr[-1])
        tb_writer.add_video(
            tag='dif_video',
            vid_tensor=np.stack(video_arr)[None].transpose(0,1,4,2,3),
            global_step=itr, fps=fps)

        if not args.debug:
            frames = (np.stack(video_arr)*255).astype(np.uint8).transpose(0,3,1,2)

        video_arr = []
        with rutil.VideoWriter(os.path.join(video_logs_dir, f'pred_{itr}.mp4'), fps=fps) as vid:
            for t in range(x_pred_list_vis.shape[0]):
                video_arr.append(rgb_mix(rgb_gt, *render_img(x_pred_list_vis[t])))
                vid(video_arr[-1])
        tb_writer.add_video(
            tag='pred_progress',
            vid_tensor=np.stack(video_arr)[None].transpose(0,1,4,2,3),
            global_step=itr, fps=fps)

        if not args.debug:
            frames2 = (np.stack(video_arr)*255).astype(np.uint8).transpose(0,3,1,2)
            wandb_run.log({"dif_video": wandb.Video(frames, fps=fps), "pred_video": wandb.Video(frames2, fps=fps)}, step=itr)

    # perturb ground truth at several noise levels and let the model recover it
    jkey, subkey = jax.random.split(jkey)
    gt_obj_pad = gt_obj.valid_obj_padding(args.nparticles, subkey)
    perturbe_recover_t_list = [0.1, 0.5, 1.0]
    ptb_rec_obj_list = []
    for t_ in perturbe_recover_t_list:
        ptb_rec_obj_list.append(dfutil.perturb_recover_obj(models_, gt_obj_pad, model_apply_jit, cond, jnp.array(t_), jkey))
        _, jkey = jax.random.split(jkey)
    ptb_rec_obj_list_vis = jax.tree_util.tree_map(lambda x: x[vis_idx], ptb_rec_obj_list)
    vis_obj_list = (gt_obj_sq.drop_gt_info(), *x_inf_list_vis, *ptb_rec_obj_list_vis)

    rgb_list = [render_img(vo)[0] for vo in vis_obj_list]

    # 3x3 overview: input image, ground truth, 4 sampling runs, 3 recoveries.
    fig_ = plt.figure()
    plt.subplot(3,3,1)
    plt.imshow(rgb_gt)
    plt.axis('off')
    for i, img in enumerate(rgb_list[:8]):
        plt.subplot(3,3,i+2)
        plt.imshow(img)
        plt.axis('off')

    fig_.canvas.draw()
    rgb = np.array(fig_.canvas.renderer._renderer)
    plt.close()

    # feature renderer imgs (needs input images, so skipped when rgbs is None)
    if args.render_loss_weight!=0 and eval_batch["rgbs"] is not None:
        output_pixel_size = (pixel_size[0]//2, pixel_size[1]//2)
        cond_frender:structs.ImgFeatures = models_.apply('spatial_PE', cond, train=False)

        env_oriCORNs_target:loutil.LatentObjects = loutil.LatentObjects().init_obj_info(eval_batch["env_info"], jax.lax.stop_gradient(models_.mesh_aligned_canonical_obj), models_.rot_configs)
        env_oriCORNs_target = jax.lax.stop_gradient(env_oriCORNs_target[:,-1:]) # only sink
        oriCORNs_target = gt_obj.concat(env_oriCORNs_target, axis=1)
        # Render from the last camera and compare against the last input view.
        oriCORN_rendered, _ = models_.apply('feature_renderer', oriCORNs_target, cond_frender, cond_frender.cam_posquat[...,-1,:], 
                                            cond_frender.intrinsic[...,-1,:], output_pixel_size, train=False)
        rgb_nv_target = eval_batch["rgbs"][...,-1,:,:,:]
        if rgb_nv_target.dtype == jnp.uint8:
            rgb_nv_target = rgb_nv_target.astype(jnp.float32)/255.
        rgb_nv_target = jax.image.resize(rgb_nv_target, rgb_nv_target.shape[:-3] + output_pixel_size + (3,), method='nearest')
        oriCORN_render_vis = jnp.stack([oriCORN_rendered[:1], rgb_nv_target[:1]], axis=0)
        oriCORN_render_vis = einops.rearrange(oriCORN_render_vis, 'B N C H W -> (B C) (N H) W')
        oriCORN_render_vis = np.array(oriCORN_render_vis)
        rgb_vis = oriCORN_render_vis
    else:
        rgb_vis = rgb

    cal_xpred_dif_jit = jax.jit(partial(trutil.cal_xpred_dif, args=args, loss_type='l2'))
    name_list = [f'inf_{dnt}' for dnt in dif_num_timesteps] + [f'ptb_{t}' for t in perturbe_recover_t_list]
    pred_aux_info_list = [cal_xpred_dif_jit(pred, None, None, gt_obj)[2] for pred in x_inf_list + ptb_rec_obj_list]

    pwloss_pred_dict = {name+'_pwloss':np.mean(pred_aux_info_list[i]['chloss']) for i, name in enumerate(name_list)}
    pwfps_pred_dict = {name+'_pwfps':np.mean(pred_aux_info_list[i]['ch_fps']) for i, name in enumerate(name_list)}
    pwz_pred_dict = {name+'_pwz':np.mean(pred_aux_info_list[i]['ch_z']) for i, name in enumerate(name_list)}

    eval_metric_dict = {
        'eval_seg_iou': seg_iou,
        'eval_inf_5_time':time_list[0],
        'eval_inf_50_time':time_list[-1],
        **models.cal_statics(),
        **pwloss_pred_dict,
        **pwfps_pred_dict,
        **pwz_pred_dict
        }

    return rgb_vis, eval_metric_dict




def eval_models_metric(models:mutil.Models, eval_batch:dict, params, jkey, args):
    """Run sampling and perturb-recover evaluation on ``eval_batch`` and collect metrics.

    Unless ``args.single_ds`` is set, ground-truth objects and cameras are
    moved by the same random translation (uniform in [-0.1, 0.1]) and
    z-rotation (uniform in [-pi, pi]). The model is sampled with 5/20/50 Euler
    steps using one shared key, and ground-truth objects are perturbed and
    recovered at t = 0.1/0.5/1.0. Each prediction is compared with the ground
    truth through ``trutil.cal_xpred_dif`` (``loss_type='l2'``).

    Args:
        models: Model container; ``params`` are bound with ``set_params``.
        eval_batch: Batch with ``'obj_info'``, ``'cam_info'``, ``'rgbs'`` and ``'seg'``.
        params: Parameters to evaluate.
        jkey: JAX PRNG key.
        args: Run configuration (``single_ds``, ``nparticles``,
            ``train_seg_only`` are read here; the full object is forwarded to
            ``trutil.cal_xpred_dif``).

    Returns:
        ``(eval_metric_dict, pred_outputs, observations)``:
        ``eval_metric_dict`` maps ``inf_<steps>_*`` / ``ptb_<t>_*`` names to the
        mean ``chloss`` / ``ch_fps`` / ``ch_z`` values; ``pred_outputs`` holds
        the ground truth, the sampled objects, the trajectories of the last
        sampling run and the recovered objects; ``observations`` holds the
        images, segmentation and (possibly transformed) cameras.
    """
    _, jkey = jax.random.split(jkey)
    models_ = models.set_params(params, update_trainable_only=True)
    models = models_

    obj_info = eval_batch["obj_info"]
    ns = eval_batch["cam_info"]['cam_intrinsics'].shape[0]

    gt_obj = loutil.LatentObjects().init_obj_info(obj_info, models.mesh_aligned_canonical_obj, models.rot_configs)
    if not args.single_ds:
        # Random rigid offset, applied to the objects here and to the cameras below.
        pos_randomization = jax.random.uniform(jkey, (ns, 1, 3), minval=-0.1, maxval=0.1)
        _, jkey = jax.random.split(jkey)
        z_rotation = jax.random.uniform(jkey, (ns, 1, 1), minval=-np.pi, maxval=np.pi)
        _, jkey = jax.random.split(jkey)
        z_rotation = jnp.c_[jnp.zeros_like(z_rotation), jnp.zeros_like(z_rotation), z_rotation]
        z_rotation = tutil.aa2q(z_rotation)
        gt_obj = gt_obj.apply_pq_z(pos_randomization, z_rotation, models.rot_configs)

    nh = gt_obj.h.shape[-1]

    intrinsic = eval_batch["cam_info"]['cam_intrinsics'].astype(np.float32)
    cam_posquat = eval_batch["cam_info"]['cam_posquats'].astype(np.float32)

    if not args.single_ds:
        cam_posquat = tutil.pq_multi(pos_randomization, z_rotation, cam_posquat[...,:3], cam_posquat[...,3:])
        cam_posquat = jnp.concatenate(cam_posquat, axis=-1)

    cond = models_.apply('img_encoder', eval_batch["rgbs"], cam_posquat, intrinsic, dino_feat_out=args.train_seg_only==1)

    # Despite the name this wrapper is not jitted here; it only binds models and params.
    model_apply_jit = partial(model_apply_param_jit, models_, params)

    # sampling: the same key is used for every step count
    _, jkey = jax.random.split(jkey)
    dif_num_timesteps = [5, 20, 50]
    x_inf_list = []
    for max_time_steps in dif_num_timesteps:
        x_pred, aux_mask = dfutil.euler_sampler_obj_single_fori(models_, (ns, args.nparticles, nh), model_apply_jit, cond, jkey, conf_filter_out=False,
                                                max_time_steps=max_time_steps, update_cam_pq=False)
        x_pred = jax.block_until_ready(x_pred)
        x_inf_list.append(x_pred)

    # Trajectories of the last (50-step) run only.
    x_pred_list = aux_mask['x_pred_list']
    x_dif_list = aux_mask['x_diffusion_list']

    # perturb ground truth at several noise levels and let the model recover it
    jkey, subkey = jax.random.split(jkey)
    gt_obj_pad = gt_obj.valid_obj_padding(args.nparticles, subkey)
    perturbe_recover_t_list = [0.1, 0.5, 1.0]
    ptb_rec_obj_list = []
    for t_ in perturbe_recover_t_list:
        ptb_rec_obj_list.append(dfutil.perturb_recover_obj(models_, gt_obj_pad, model_apply_jit, cond, jnp.array(t_), jkey))
        _, jkey = jax.random.split(jkey)

    cal_xpred_dif_jit = jax.jit(partial(trutil.cal_xpred_dif, args=args, loss_type='l2'))
    name_list = [f'inf_{dnt}' for dnt in dif_num_timesteps] + [f'ptb_{t}' for t in perturbe_recover_t_list]
    pred_aux_info_list = [cal_xpred_dif_jit(pred, None, None, gt_obj)[2] for pred in x_inf_list + ptb_rec_obj_list]

    pwloss_pred_dict = {name+'_pwloss':jnp.mean(pred_aux_info_list[i]['chloss']) for i, name in enumerate(name_list)}
    pwfps_pred_dict = {name+'_pwfps':jnp.mean(pred_aux_info_list[i]['ch_fps']) for i, name in enumerate(name_list)}
    pwz_pred_dict = {name+'_pwz':jnp.mean(pred_aux_info_list[i]['ch_z']) for i, name in enumerate(name_list)}

    eval_metric_dict = {
        **pwloss_pred_dict,
        **pwfps_pred_dict,
        **pwz_pred_dict
        }

    pred_outputs = {
        'x_gt': gt_obj,
        'x_inf_list': x_inf_list,
        'x_dif_list': x_dif_list,
        'x_pred_list': x_pred_list,
        'ptb_rec_obj_list': ptb_rec_obj_list
    }

    observations = {
        'rgbs': eval_batch["rgbs"],
        'seg': eval_batch["seg"],
        'cam_posquat': cam_posquat,
        'intrinsic': intrinsic,
    }

    return eval_metric_dict, pred_outputs, observations




def eval_visualization_log(models:mutil.Models, params, eval_pred_outputs, eval_observations, jkey, itr, logs_dir, tb_writer, args, wandb_run):
    """Visualize precomputed evaluation predictions and log trajectory videos.

    Consumes dictionaries shaped like the ``pred_outputs`` / ``observations``
    returned by ``eval_models_metric``. The first sample and the first view
    are visualized. Every ``2 * args.save_interval`` iterations the diffusion
    / prediction trajectories are written as mp4 files under
    ``<logs_dir>/video``, sent to ``tb_writer`` and, when ``args.debug`` is
    false, to ``wandb_run``.

    Args:
        models: Model container; ``params`` are bound with ``set_params``.
        params: Parameters to evaluate.
        eval_pred_outputs: Dict with ``'x_gt'``, ``'x_inf_list'``,
            ``'x_dif_list'``, ``'x_pred_list'`` and ``'ptb_rec_obj_list'``.
        eval_observations: Dict with ``'rgbs'``, ``'seg'``, ``'cam_posquat'``
            and ``'intrinsic'``. An optional ``'env_info'`` entry enables the
            feature-renderer comparison when ``args.render_loss_weight != 0``.
        jkey: Unused; kept so existing callers keep working.
        itr: Current training iteration (controls video logging and log steps).
        logs_dir: Directory under which the ``video`` folder is created.
        tb_writer: Writer exposing ``add_video``.
        args: Run configuration (``nparticles``, ``fps_only``, ``train_seg``,
            ``train_seg_only``, ``save_interval``, ``dm_type``, ``debug``,
            ``render_loss_weight`` are read here).
        wandb_run: Object exposing ``log``; only used when ``args.debug`` is false.

    Returns:
        The visualization image. NOTE: in ``train_seg_only`` mode a tuple
        ``(image, {'eval_seg_iou': ...})`` is returned instead, as before.

    Raises:
        ValueError: In ``train_seg_only`` mode when ``args.train_seg`` is off,
            so that no segmentation prediction exists.
    """
    models_ = models.set_params(params, update_trainable_only=True)
    models = models_

    gt_obj = eval_pred_outputs['x_gt']
    x_inf_list = eval_pred_outputs['x_inf_list']
    x_dif_list = eval_pred_outputs['x_dif_list']
    x_pred_list = eval_pred_outputs['x_pred_list']
    ptb_rec_obj_list = eval_pred_outputs['ptb_rec_obj_list']

    rgbs = eval_observations['rgbs']
    seg = eval_observations['seg']
    cam_posquat = eval_observations['cam_posquat']
    intrinsic = eval_observations['intrinsic']

    # Number of views per sample (rgbs is indexed as [sample][view] below).
    nv = rgbs.shape[-4]

    vis_idx = 0
    vp_idx = 0
    vis_intrinsic = intrinsic[vis_idx,vp_idx]
    vis_cam_posquat = cam_posquat[vis_idx,vp_idx]
    pixel_size = (int(vis_intrinsic[1]), int(vis_intrinsic[0]))
    vis_cam_pos, vis_cam_quat = vis_cam_posquat[:3], vis_cam_posquat[3:]
    sdf_func = partial(rutil.scene_sdf, models=models_)
    latent_render_func = lambda: jax.jit(partial(rutil.cvx_render_scene, models=models_, sdf=sdf_func, pixel_size=pixel_size, 
                                                    intrinsic=vis_intrinsic, camera_pos=vis_cam_pos, camera_quat=vis_cam_quat, seg_out=True))
    # One jitted renderer per possible number of valid objects.
    latent_render_func_dict = {}
    for i in range(1, args.nparticles+1):
        latent_render_func_dict[i] = latent_render_func()
    gt_obj_sq = gt_obj[vis_idx]

    def render_img(obj:loutil.LatentObjects):
        """Render ``obj`` from the visualization camera and mark its projected points in white."""
        if obj.nobj != 1:
            obj = obj[obj.obj_valid_mask]
        if obj.outer_shape[0] == 0:
            return np.zeros((*pixel_size, 3)), np.zeros(pixel_size)

        # In fps_only mode the scene render is skipped and only points are drawn.
        if args.fps_only:
            rgb_, seg_ = np.zeros((*pixel_size, 3)), np.zeros(pixel_size)
        else:
            rgb_, seg_ = latent_render_func_dict[obj.outer_shape[0]](obj)

        pixel_coord, out_ = cutil.global_pnts_to_pixel(vis_intrinsic, vis_cam_posquat, obj.fps_tf) # (... NR)

        # In fps_only mode each point is thickened with its 4-neighbourhood.
        if args.fps_only:
            pixel_coord = pixel_coord[...,None,:] + np.array([[[0,0],[0,1],[1,0],[0,-1],[-1,0]]])
        else:
            pixel_coord = pixel_coord[...,None,:] + np.array([[[0,0]]])

        # Clamp to the image; the checks after the clip are kept as extra guards.
        pixel_coord = np.array(pixel_coord).astype(np.int32).clip(0, vis_intrinsic[:2][::-1]-1).astype(np.int32)
        pixel_coord = np.where(np.isnan(pixel_coord), 0, pixel_coord)
        pixel_coord = np.where(pixel_coord<0, 0, pixel_coord)
        pixel_coord = np.where(pixel_coord>10000, 0, pixel_coord)

        pixel_coord = pixel_coord.reshape(pixel_coord.shape[0], -1, 2)

        rgb_ = np.array(rgb_)
        for i in range(pixel_coord.shape[0]):
            rgb_[pixel_coord[i,:,0], pixel_coord[i,:,1]] = np.ones(3)
        return rgb_, seg_

    cond = models_.apply('img_encoder', rgbs, cam_posquat, intrinsic, dino_feat_out=args.train_seg_only==1)
    seg_iou = 0
    seg_pred_vis = None
    if args.train_seg:
        seg_pred = models_.apply('seg_predictor', cond)
        seg_pred = nn.sigmoid(seg_pred)
        seg_label = seg.astype(jnp.float32)
        # Soft IoU: sum(p*l) / sum(p + l - p*l).
        seg_iou = jnp.sum(seg_pred.squeeze(-1) * seg_label)/(jnp.sum(seg_pred.squeeze(-1) + seg_label - seg_pred.squeeze(-1) * seg_label) + 1e-6)
        seg_pred_vis = seg_pred[vis_idx, vp_idx]

    rgb_gt = rgbs[vis_idx][vp_idx]

    def rgb_mix(rgb, pred, seg=None):
        """Blend ``rgb`` (scaled by 1/255) with ``pred``; ``pred`` is zeroed where ``seg < 0``."""
        alpha = 0.65
        if seg is not None:
            pred_ = np.where(seg[...,None]>=0, pred, 0)
        else:
            pred_ = pred
        return (1-alpha)*rgb/255. + pred_*alpha

    if args.train_seg_only:
        if seg_pred_vis is None:
            raise ValueError("train_seg_only visualization needs args.train_seg to be set")
        fig_ = plt.figure()
        plt.imshow(rgb_mix(rgb_gt, seg_pred_vis))
        fig_.canvas.draw()
        rgb = np.array(fig_.canvas.renderer._renderer)
        plt.close()

        eval_metric_dict = {
            'eval_seg_iou': seg_iou,
        }

        return rgb, eval_metric_dict

    x_inf_list_vis, x_dif_list_vis, x_pred_list_vis = \
        jax.tree_util.tree_map(lambda x: x[vis_idx], (x_inf_list, x_dif_list, x_pred_list))

    if itr%(args.save_interval*2) == 0:
        fps = 15 if args.dm_type != 'regression' else 3
        video_logs_dir = os.path.join(logs_dir, 'video')
        os.makedirs(video_logs_dir, exist_ok=True)
        video_arr = []
        with rutil.VideoWriter(os.path.join(video_logs_dir, f'dif_{itr}.mp4'), fps=fps) as vid:
            for t in range(x_dif_list_vis.shape[0]):
                video_arr.append(rgb_mix(rgb_gt, *render_img(x_dif_list_vis[t])))
                vid(video_arr[-1])
        tb_writer.add_video(
            tag='dif_video',
            vid_tensor=np.stack(video_arr)[None].transpose(0,1,4,2,3),
            global_step=itr, fps=fps)

        if not args.debug:
            frames = (np.stack(video_arr)*255).astype(np.uint8).transpose(0,3,1,2)

        video_arr = []
        with rutil.VideoWriter(os.path.join(video_logs_dir, f'pred_{itr}.mp4'), fps=fps) as vid:
            for t in range(x_pred_list_vis.shape[0]):
                video_arr.append(rgb_mix(rgb_gt, *render_img(x_pred_list_vis[t])))
                vid(video_arr[-1])
        tb_writer.add_video(
            tag='pred_progress',
            vid_tensor=np.stack(video_arr)[None].transpose(0,1,4,2,3),
            global_step=itr, fps=fps)

        if not args.debug:
            frames2 = (np.stack(video_arr)*255).astype(np.uint8).transpose(0,3,1,2)
            wandb_run.log({"dif_video": wandb.Video(frames, fps=fps), "pred_video": wandb.Video(frames2, fps=fps)}, step=itr)

    ptb_rec_obj_list_vis = jax.tree_util.tree_map(lambda x: x[vis_idx], ptb_rec_obj_list)
    vis_obj_list = (gt_obj_sq.drop_gt_info(), *x_inf_list_vis, *ptb_rec_obj_list_vis)

    rgb_list = [render_img(vo)[0] for vo in vis_obj_list]

    # 3x3 overview: two input views followed by up to seven rendered panels.
    fig_ = plt.figure()
    plt.subplot(3,3,1)
    plt.imshow(rgb_gt)
    plt.axis('off')
    plt.subplot(3,3,2)
    # Second view, clamped so single-view input does not index out of range.
    plt.imshow(rgbs[vis_idx][min(vp_idx+1, nv-1)])
    plt.axis('off')
    # Only seven cells remain in the grid; extra renders are not drawn.
    for i, img in enumerate(rgb_list[:7]):
        plt.subplot(3,3,i+3)
        plt.imshow(img)
        plt.axis('off')

    fig_.canvas.draw()
    rgb = np.array(fig_.canvas.renderer._renderer)
    plt.close()

    # feature renderer imgs
    # This branch previously read an undefined ``eval_batch`` (NameError). The
    # images come from ``eval_observations``; the environment description is
    # read from its optional 'env_info' entry.
    env_info = eval_observations.get('env_info') if args.render_loss_weight!=0 else None
    if args.render_loss_weight!=0 and env_info is None:
        print("eval_visualization_log: 'env_info' missing in eval_observations; skipping feature-renderer visualization")
    if env_info is not None:
        output_pixel_size = (pixel_size[0]//2, pixel_size[1]//2)
        cond_frender:structs.ImgFeatures = models_.apply('spatial_PE', cond, train=False)

        env_oriCORNs_target:loutil.LatentObjects = loutil.LatentObjects().init_obj_info(env_info, jax.lax.stop_gradient(models_.mesh_aligned_canonical_obj), models_.rot_configs)
        env_oriCORNs_target = jax.lax.stop_gradient(env_oriCORNs_target[:,-1:]) # only sink
        oriCORNs_target = gt_obj.concat(env_oriCORNs_target, axis=1)
        # Render from the last camera and compare against the last input view.
        oriCORN_rendered, _ = models_.apply('feature_renderer', oriCORNs_target, cond_frender, cond_frender.cam_posquat[...,-1,:], 
                                            cond_frender.intrinsic[...,-1,:], output_pixel_size, train=False)
        rgb_nv_target = rgbs[...,-1,:,:,:]
        if rgb_nv_target.dtype == jnp.uint8:
            rgb_nv_target = rgb_nv_target.astype(jnp.float32)/255.
        rgb_nv_target = jax.image.resize(rgb_nv_target, rgb_nv_target.shape[:-3] + output_pixel_size + (3,), method='nearest')
        oriCORN_render_vis = jnp.stack([oriCORN_rendered[:1], rgb_nv_target[:1]], axis=0)
        oriCORN_render_vis = einops.rearrange(oriCORN_render_vis, 'B N C H W -> (B C) (N H) W')
        oriCORN_render_vis = np.array(oriCORN_render_vis)
        rgb_vis = oriCORN_render_vis
    else:
        rgb_vis = rgb

    return rgb_vis



def generate_dif_video(logs_dir, models:mutil.Models, pred_oriCORNs:typing.List[loutil.LatentObjects], rgb_vis, vis_cam_posquat, vis_intrinsic, 
                       eval_batch_ds=None, fps_only=False):
    """Write a video that overlays a sequence of predicted objects on a background image.

    Args:
        logs_dir: Path of the output video file (it is passed straight to
            ``rutil.VideoWriter``); its parent directory is created if needed.
        models: Model container passed to the SDF and render utilities.
        pred_oriCORNs: Frames to render, either a list of ``LatentObjects`` or
            a single stacked ``LatentObjects`` that is split along its leading
            axis (a rank-3 stack is first reduced to its first element).
        rgb_vis: Background image; it is scaled by 1/255 before blending.
        vis_cam_posquat: Camera position (first 3 entries) and quaternion (rest).
        vis_intrinsic: Camera intrinsics; the first two entries give the image size.
        eval_batch_ds: Optional batch. When given, the camera and background
            image are taken from its first sample / first view and override
            ``rgb_vis``, ``vis_cam_posquat`` and ``vis_intrinsic``.
        fps_only: When true the scene render is skipped and only the projected
            points are drawn on a black image.

    Returns:
        None. The video is written to ``logs_dir`` at ``len(frames) / 4``
        frames per second.
    """
    if isinstance(pred_oriCORNs, loutil.LatentObjects):
        if len(pred_oriCORNs.shape) == 3:
            pred_oriCORNs = pred_oriCORNs[0]
        pred_oriCORNs = [pred_oriCORNs[i] for i in range(pred_oriCORNs.shape[0])]

    if eval_batch_ds is not None:
        vis_idx = 0
        vp_idx = 0
        intrinsic = eval_batch_ds["cam_info"]['cam_intrinsics'].astype(np.float32)
        cam_posquat = eval_batch_ds["cam_info"]['cam_posquats'].astype(np.float32)
        vis_intrinsic = intrinsic[vis_idx,vp_idx]
        vis_cam_posquat = cam_posquat[vis_idx,vp_idx]
        rgb_vis = eval_batch_ds["rgbs"][vis_idx][vp_idx]
    pixel_size = (int(vis_intrinsic[1]), int(vis_intrinsic[0]))
    vis_cam_pos, vis_cam_quat = vis_cam_posquat[:3], vis_cam_posquat[3:]
    sdf_func = partial(rutil.scene_sdf, models=models)
    latent_render_func = lambda: jax.jit(partial(rutil.cvx_render_scene, models=models, sdf=sdf_func, pixel_size=pixel_size, 
                                                    intrinsic=vis_intrinsic, camera_pos=vis_cam_pos, camera_quat=vis_cam_quat, seg_out=True))
    # One jitted renderer per object count, created on first use. The previous
    # fixed table (1..14 objects) raised KeyError for larger scenes.
    latent_render_func_dict = {}

    def get_renderer(nobj):
        """Return the jitted renderer dedicated to scenes with ``nobj`` objects."""
        if nobj not in latent_render_func_dict:
            latent_render_func_dict[nobj] = latent_render_func()
        return latent_render_func_dict[nobj]

    def render_img(obj:loutil.LatentObjects):
        """Render ``obj`` and return ``(rgb, seg, pixel_coord)``; only ``(rgb, seg)`` when no object is valid."""
        if obj.nobj != 1:
            obj = obj[obj.obj_valid_mask]
        if obj.outer_shape[0] == 0:
            return np.zeros((*pixel_size, 3)), np.zeros(pixel_size)

        if fps_only:
            rgb_, seg_ = np.zeros((*pixel_size, 3)), np.zeros(pixel_size)
        else:
            rgb_, seg_ = get_renderer(obj.outer_shape[0])(obj)

        pixel_coord, out_ = cutil.global_pnts_to_pixel(vis_intrinsic, vis_cam_posquat, obj.fps_tf) # (... NR)

        pixel_coord = pixel_coord[...,None,:]

        # Clamp to the image; the checks after the clip are kept as extra guards.
        pixel_coord = np.array(pixel_coord).astype(np.int32).clip(0, vis_intrinsic[:2][::-1]-1).astype(np.int32)
        pixel_coord = np.where(np.isnan(pixel_coord), 0, pixel_coord)
        pixel_coord = np.where(pixel_coord<0, 0, pixel_coord)
        pixel_coord = np.where(pixel_coord>10000, 0, pixel_coord)

        pixel_coord = pixel_coord.reshape(pixel_coord.shape[0], -1, 2)

        # Points are drawn after blending (see rgb_mix) so they stay pure white.
        rgb_ = np.array(rgb_)
        return rgb_, seg_, pixel_coord

    def rgb_mix(rgb, pred, seg=None, pixel_coord=None):
        """Blend ``rgb`` (scaled by 1/255) with ``pred`` and paint ``pixel_coord`` white."""
        alpha = 0.65
        if seg is not None:
            pred_ = np.where(seg[...,None]>=0, pred, 0)
        else:
            pred_ = pred
        rgb_res = (1-alpha)*rgb/255. + pred_*alpha
        rgb_res = np.array(rgb_res)
        if pixel_coord is not None:
            for i in range(pixel_coord.shape[0]):
                rgb_res[pixel_coord[i,:,0], pixel_coord[i,:,1]] = np.ones(3)
        return rgb_res

    # The whole sequence always plays in four seconds.
    fps = len(pred_oriCORNs)/4
    # ``os.makedirs('')`` raises FileNotFoundError, so a bare file name (no
    # directory part) must skip directory creation.
    out_dir = os.path.dirname(logs_dir)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    video_arr = []
    with rutil.VideoWriter(logs_dir, fps=fps) as vid:
        for t in pred_oriCORNs:
            video_arr.append(rgb_mix(rgb_vis, *render_img(t)))
            vid(video_arr[-1])



def aligne_two_obj_wrt_obj1(obj1:loutil.LatentObjects, obj2:loutil.LatentObjects, args)->typing.Tuple[loutil.LatentObjects, jnp.ndarray]:
    """Re-order ``obj2`` against ``obj1`` at object level and at fps level.

    ``bputil.obj_matching`` is run first (with ``base_order='pred'``) and its
    second and fourth outputs are kept as the aligned object set and its valid
    mask. ``bputil.fps_matching`` is then run between ``obj1`` and the aligned
    set, and column 1 of the returned pair indices is used to gather
    ``rel_fps`` (along axis -2) and ``z`` (along axis -3) of the aligned set.

    Args:
        obj1: reference objects.
        obj2: objects to re-order; ``obj2.obj_valid_mask`` is passed as the
            target valid mask.
        args: config object providing ``pos_loss_coef``, ``dc_pos_loss_coef``,
            ``fps_only`` and ``train_loss_type``.

    Returns:
        ``(obj2_aligned, matched_obj2_valid_mask)``.
    """
    # Object-level matching; only the aligned objects and their mask are used.
    _, aligned, _, aligned_valid_mask, _ = bputil.obj_matching(
        obj1,
        obj2,
        target_obj_valid_mask=obj2.obj_valid_mask,
        pos_loss_coef=args.pos_loss_coef,
        dc_pos_loss_coef=args.dc_pos_loss_coef,
        base_order='pred',
        fps_only=args.fps_only,
        dif_type=args.train_loss_type,
        callback=None,
    )

    # Point-level matching between the reference and the aligned objects.
    fps_pairs, _, _ = bputil.fps_matching(
        obj1, aligned, dif_type=args.train_loss_type, fps_only=args.fps_only)

    # Column 1 of the pair indices selects entries of the aligned objects.
    gather_idx = fps_pairs[..., 1:2]
    reordered_rel_fps = jnp.take_along_axis(aligned.rel_fps, gather_idx, axis=-2)
    reordered_z = jnp.take_along_axis(aligned.z, gather_idx[..., None], axis=-3)
    aligned = aligned.replace(rel_fps=reordered_rel_fps, z=reordered_z)

    return aligned, aligned_valid_mask

def cal_vel_dif(x_vel_pred:loutil.LatentObjects, x0_gt:loutil.LatentObjects, x0_noise:loutil.LatentObjects, conf_pred, args, loss_type='l1'):
    """Masked velocity-regression loss plus an optional confidence loss.

    The regression target is built by aligning ``x0_gt`` to ``x0_noise`` with
    ``aligne_two_obj_wrt_obj1`` and subtracting ``x0_noise.h`` from the aligned
    ``h``. The per-object loss is a weighted sum of l1 distances on ``fps_tf``,
    ``pos`` and (depending on ``args.fps_only``) ``z_flat``; it is multiplied
    by the aligned valid mask and summed over the last axis.

    Args:
        x_vel_pred: predicted velocity objects.
        x0_gt: ground-truth objects.
        x0_noise: noise objects the ground truth is aligned to.
        conf_pred: confidence logits, or ``None`` to skip the confidence loss.
        args: config object providing ``pos_loss_coef``, ``dc_pos_loss_coef``,
            ``fps_only``, ``nparticles``, ``single_ds`` and the fields read by
            ``aligne_two_obj_wrt_obj1``.
        loss_type: element-wise loss selector; only ``'l1'`` is implemented.

    Returns:
        ``(dif_loss, conf_loss_, em_loss_log)`` where ``em_loss_log`` holds the
        masked ``pos_dif``, ``ch_fps`` and ``ch_z`` terms plus
        ``conf_inconsistency_cnt``.

    Raises:
        ValueError: if ``loss_type`` is not ``'l1'``.
    """
    # Fail early and explicitly: no other element-wise loss is defined here,
    # and without this check the call below would hit a None loss function.
    if loss_type != 'l1':
        raise ValueError(f"cal_vel_dif only implements loss_type='l1', got {loss_type!r}")

    def l1_dist(x, y):
        return jnp.abs(x - y).sum(axis=-1)

    x0_gt_aligned, obj2_valid_mask_aligned = aligne_two_obj_wrt_obj1(x0_noise, x0_gt, args)
    vel_gt = x0_gt_aligned.set_h(x0_gt_aligned.h - x0_noise.h)

    # The l1 path uses the square root of the configured coefficients.
    pos_loss_coef = np.sqrt(args.pos_loss_coef)
    dc_pos_loss_coef = np.sqrt(args.dc_pos_loss_coef)

    # NOTE(review): the z term is only computed when args.fps_only is truthy,
    # which reads inverted relative to the flag name -- confirm the intent.
    if args.fps_only:
        ch_z = jnp.sum(l1_dist(x_vel_pred.z_flat, vel_gt.z_flat), axis=-1)
    else:
        ch_z = 0
    ch_fps = jnp.sum(l1_dist(x_vel_pred.fps_tf, vel_gt.fps_tf), axis=-1)
    pos_dif = l1_dist(x_vel_pred.pos, vel_gt.pos)
    dif_loss = dc_pos_loss_coef*ch_fps + ch_z + pos_loss_coef*pos_dif
    em_loss_log = {'pos_dif': pos_dif, 'ch_fps': ch_fps, 'ch_z': ch_z}

    # Zero out objects that are invalid after alignment, then sum over the last axis.
    dif_loss, em_loss_log = jax.tree_util.tree_map(
        lambda x: jnp.sum(obj2_valid_mask_aligned * x, axis=-1), (dif_loss, em_loss_log))

    if args.nparticles > 1 and conf_pred is not None:
        # Confidence loss - label should be ordered by pred obj (0,1,2,3...)
        conf_label = obj2_valid_mask_aligned[...,None].astype(jnp.float32)

        if args.single_ds:
            conf_label = jnp.ones_like(conf_label)
        assert conf_pred.shape == conf_label.shape
        # Binary cross-entropy on clipped sigmoid probabilities (clip keeps log finite).
        conf_pred_sig = jax.nn.sigmoid(conf_pred).clip(1e-5, 1 - 1e-5)
        conf_loss_ = conf_label * jnp.log(conf_pred_sig) + (1 - conf_label) * jnp.log(1 - conf_pred_sig)
        conf_loss_ = -jnp.sum(conf_loss_, axis=(-1, -2))  # Sum over objects
        conf_loss_ = jnp.mean(conf_loss_)

        # Absolute difference between the number of positive labels and the
        # number of valid aligned objects, summed over the remaining axes.
        conf_inconsistency_cnt = jnp.sum(jnp.abs(jnp.sum(conf_label.squeeze(-1), axis=(-1,)) - jnp.sum(obj2_valid_mask_aligned, axis=-1)))
    else:
        conf_inconsistency_cnt = 0
        conf_loss_ = 0.0

    em_loss_log['conf_inconsistency_cnt'] = conf_inconsistency_cnt

    return dif_loss, conf_loss_, em_loss_log

def cal_xpred_dif(x0_pred:loutil.LatentObjects, x0_previous:loutil.LatentObjects, conf_pred, x0_gt:loutil.LatentObjects, args, loss_type='l1'):
    """Matched prediction-vs-ground-truth loss plus an optional confidence loss.

    Ground-truth objects are matched against ``x0_previous`` (or against
    ``x0_pred`` when ``x0_previous`` is ``None``) with ``bputil.obj_matching``
    and ``bputil.fps_matching``. ``bputil.obj_dif`` then compares ``x0_pred``
    with the matched ground truth, and every returned term is multiplied by
    the matched valid mask and summed over the last axis.

    Args:
        x0_pred: predicted objects.
        x0_previous: objects used as the matching reference, or ``None``.
        conf_pred: confidence logits, or ``None`` to skip the confidence loss.
        x0_gt: ground-truth objects.
        args: config object providing ``pos_loss_coef``, ``dc_pos_loss_coef``,
            ``fps_only``, ``train_loss_type``, ``nparticles`` and ``single_ds``.
        loss_type: ``'l1'`` selects an l1 element-wise loss with square-rooted
            coefficients; anything else passes the raw coefficients and
            ``loss_func=None`` to ``bputil.obj_dif``.

    Returns:
        ``(dif_loss_, conf_loss_, em_loss_log)`` where ``em_loss_log`` is the
        masked log returned by ``bputil.obj_dif`` with
        ``conf_inconsistency_cnt`` added.
    """
    gt_valid_mask = x0_gt.obj_valid_mask
    matching_ref = x0_pred if x0_previous is None else x0_previous

    # Object-level matching of the ground truth against the reference.
    _, gt_matched, match_pairs, matched_valid_mask, _ = bputil.obj_matching(
        matching_ref,
        x0_gt,
        target_obj_valid_mask=gt_valid_mask,
        pos_loss_coef=args.pos_loss_coef,
        dc_pos_loss_coef=args.dc_pos_loss_coef,
        base_order='pred',
        fps_only=args.fps_only,
        dif_type=args.train_loss_type,
        callback=None,
    )

    # Point-level matching between the reference and the matched ground truth.
    fps_pairs, _, _ = bputil.fps_matching(
        matching_ref, gt_matched, dif_type=args.train_loss_type, fps_only=args.fps_only)

    # Pick coefficients and element-wise loss according to the loss type.
    if loss_type == 'l1':
        pos_coef = np.sqrt(args.pos_loss_coef)
        dc_pos_coef = np.sqrt(args.dc_pos_loss_coef)
        elementwise_loss = lambda x, y: jnp.abs(x - y).sum(axis=-1)
    else:
        pos_coef = args.pos_loss_coef
        dc_pos_coef = args.dc_pos_loss_coef
        elementwise_loss = None

    dif_loss_, em_loss_log = bputil.obj_dif(
        x0_pred,
        gt_matched,
        dif_type=args.train_loss_type,
        pos_loss_coef=pos_coef,
        dc_pos_loss_coef=dc_pos_coef,
        loss_func=elementwise_loss,
        fps_matched_pair=fps_pairs,
        fps_only=args.fps_only,
    )

    def masked_sum(term):
        return jnp.sum(matched_valid_mask * term, axis=-1)

    dif_loss_, em_loss_log = jax.tree_util.tree_map(masked_sum, (dif_loss_, em_loss_log))

    if args.nparticles > 1 and conf_pred is not None:
        # Confidence loss - label should be ordered by pred obj (0,1,2,3...)
        # Entries masked out as invalid get index -1 so they match no slot.
        matched_pred_idx = jnp.where(matched_valid_mask[..., None], match_pairs[..., 0:1], -1)
        slot_hit = matched_pred_idx == jnp.arange(x0_pred.nobj)
        conf_label = jnp.any(slot_hit, axis=-2).astype(jnp.float32)[..., None]

        if args.single_ds:
            conf_label = jnp.ones_like(conf_label)
        assert conf_pred.shape == conf_label.shape

        # Binary cross-entropy on clipped sigmoid probabilities.
        prob = jax.nn.sigmoid(conf_pred).clip(1e-5, 1 - 1e-5)
        log_lik = conf_label * jnp.log(prob) + (1 - conf_label) * jnp.log(1 - prob)
        conf_loss_ = jnp.mean(-jnp.sum(log_lik, axis=(-1, -2)))  # Sum over objects, then mean

        # Absolute difference between the number of positive labels and the
        # number of valid ground-truth objects, summed over the remaining axes.
        n_positive = jnp.sum(conf_label.squeeze(-1), axis=(-1,))
        n_valid = jnp.sum(gt_valid_mask, axis=-1)
        conf_inconsistency_cnt = jnp.sum(jnp.abs(n_positive - n_valid))
    else:
        conf_inconsistency_cnt = 0
        conf_loss_ = 0.0

    em_loss_log['conf_inconsistency_cnt'] = conf_inconsistency_cnt

    return dif_loss_, conf_loss_, em_loss_log

if __name__ == "__main__":
    # Ad-hoc scratch script: builds two sets of random latent vectors and
    # evaluates trutil.CFCost between them. Runs only when this file is
    # executed directly.
    class Args:
        pass
    # Bare attribute container standing in for the parsed training config.
    args = Args()
    args.pos_loss_coef = 1.0
    args.dc_pos_loss_coef = 1.0
    latent_shape = (32,16,8)
    # NOTE(review): `DFCost` is referenced unqualified while the cost used
    # below is `trutil.CFCost` -- confirm `DFCost` is defined in this module.
    df_cost = DFCost(latent_shape, args)
    na = 10

    jkey = jax.random.PRNGKey(0)
    # Two batches of `na` flat vectors; the last dimension is
    # latent_shape[0]*(latent_shape[1]*latent_shape[2]+3)+3.
    h12 = jax.random.normal(jkey, shape=(2,na,latent_shape[0]*(latent_shape[1]*latent_shape[2]+3)+3))

    # NOTE(review): the only `pointcloud` import visible at the top of this
    # file is commented out -- confirm the name is in scope, otherwise this
    # line raises NameError.
    geom_ott = pointcloud.PointCloud(h12[0], h12[1], cost_fn=trutil.CFCost(latent_shape=latent_shape, args=args)) # Chamfer distance

    # The result of this call is discarded.
    trutil.CFCost(latent_shape=latent_shape, args=args).all_pairs_pairwise(h12[0], h12[1])
    # cost_val = df_cost.pairwise(h12[0], h12[1])

    # self.cost_fn.all_pairs_pairwise(self.x, self.y)

    # Bare expression statement: has no effect when run as a script.
    geom_ott
    print(1)
