from model.mdm import MDM
from model.mdm_ours import MDM as MDM_Ours
from model.mdm_ours import MDMV3 as MDM_Ours_V3
from model.mdm_ours import MDMV4 as MDM_Ours_V4
from model.mdm_ours import MDMV5 as MDM_Ours_V5
from model.mdm_ours import MDMV6 as MDM_Ours_V6
from model.mdm_ours import MDMV7 as MDM_Ours_V7
from model.mdm_ours import MDMV8 as MDM_Ours_V8
from model.mdm_ours import MDMV9 as MDM_Ours_V9
from model.mdm_ours import MDMV10 as MDM_Ours_V10
from model.mdm_ours import MDMV11 as MDM_Ours_V11
# MDM_Ours_V12
from model.mdm_ours import MDMV12 as MDM_Ours_V12
# MDM_Ours_V13
from model.mdm_ours import MDMV13 as MDM_Ours_V13
# MDM_Ours_V14
from model.mdm_ours import MDMV14 as MDM_Ours_V14
from diffusion import gaussian_diffusion as gd
from diffusion.respace import SpacedDiffusion, space_timesteps
from utils.parser_util import get_cond_mode
import torch
from torch import optim, nn
import torch.nn.functional as F
from manopth.manolayer import ManoLayer
import numpy as np
import trimesh
import os
from diffusion.respace_ours import SpacedDiffusion as SpacedDiffusion_Ours
# SpacedDiffusionV2
from diffusion.respace_ours import SpacedDiffusionV2 as SpacedDiffusion_OursV2
from diffusion.respace_ours import SpacedDiffusionV3 as SpacedDiffusion_OursV3
# SpacedDiffusionV4
from diffusion.respace_ours import SpacedDiffusionV4 as SpacedDiffusion_OursV4
# SpacedDiffusion_OursV5
from diffusion.respace_ours import SpacedDiffusionV5 as SpacedDiffusion_OursV5
# SpacedDiffusion_OursV6
from diffusion.respace_ours import SpacedDiffusionV6 as SpacedDiffusion_OursV6
# SpacedDiffusion_OursV7
from diffusion.respace_ours import SpacedDiffusionV7 as SpacedDiffusion_OursV7
from diffusion.respace_ours import SpacedDiffusionV9 as SpacedDiffusion_OursV9



def batched_index_select_ours(values, indices, dim = 1):
    """Gather entries of ``values`` along axis ``dim`` with batched ``indices``.

    The leading ``dim`` axes are batch axes shared by ``values`` and
    ``indices``. Any axes of ``indices`` beyond position ``dim`` other than the
    last index axis are treated as extra index axes: ``values`` is broadcast
    over them before gathering. Trailing feature axes of ``values`` (those
    after ``dim``) are carried through unchanged.

    Returns a tensor of shape ``indices.shape + values.shape[dim + 1:]``.
    """
    trailing_feat = values.shape[dim + 1:]
    n_idx_axes = len(indices.shape)
    n_extra = n_idx_axes - (dim + 1)

    # Indices: append one singleton axis per feature axis, then broadcast to the feature sizes.
    idx = indices[(Ellipsis,) + (None,) * len(trailing_feat)]
    idx = idx.expand(*([-1] * n_idx_axes), *trailing_feat)

    # Values: insert singleton axes at ``dim`` for the extra index axes and broadcast over them.
    src = values[(slice(None),) * dim + (None,) * n_extra + (Ellipsis,)]
    expand_sizes = [-1] * len(src.shape)
    expand_sizes[dim:dim + n_extra] = idx.shape[dim:dim + n_extra]
    src = src.expand(*expand_sizes)

    # After the insertion, the axis being indexed has moved right by ``n_extra``.
    return src.gather(dim + n_extra, idx)



def gaussian_entropy(logvar): # gaussian entropy ##
    """Entropy of a diagonal Gaussian parameterised by its log-variances.

    Uses H = 0.5 * D * (1 + log(2*pi)) + 0.5 * sum(logvar), where D is
    ``logvar.size(1)`` and the sum runs over axis 1.
    Returns one entropy per row (axis 1 is reduced away).
    """
    n_dims = logvar.size(1)
    log_two_pi_plus_one = 1. + np.log(2 * np.pi)
    const = 0.5 * float(n_dims) * log_two_pi_plus_one
    return 0.5 * logvar.sum(dim=1, keepdim=False) + const


def standard_normal_logprob(z): # feature dim
    """Element-wise standard-normal log-density term for ``z``.

    Returns ``-0.5 * D * log(2*pi) - z**2 / 2`` with the same shape as ``z``,
    where D is ``z.size(-1)``.

    NOTE(review): the full D-dimensional normalising constant is added to
    *every* element, not just once per vector. If callers sum the result over
    the last axis, the constant is counted D times. It is harmless for
    gradients, but reported log-likelihoods are shifted.
    """
    half_sq = z.pow(2) / 2
    log_norm = -0.5 * z.size(-1) * np.log(2 * np.pi)
    return log_norm - half_sq


def load_multiple_models_fr_path(model_path, model):
    """Load ``model``'s weights from several checkpoints, one per sub-network.

    Args:
        model_path: ``;``-separated ``setting:checkpoint_path`` entries, e.g.
            ``"diff_basejtsrel:/ckpt/a.pt;diff_realbasejtsrel:/ckpt/b.pt"``.
            Only the first ``:`` of an entry separates the setting from the
            path, so paths may themselves contain ``:``. Empty entries (e.g.
            from a trailing ``;``) are ignored. Surrounding whitespace around
            the setting and the path is stripped. Supported settings:
            ``diff_realbasejtsrel``, ``diff_basejtsrel`` and
            ``diff_realbasejtsrel_to_joints``.
        model: ``nn.Module`` whose state dict is updated in place.

    For each entry the checkpoint is loaded on the CPU. Every key that
    *contains* one of the setting's module-name fragments is copied. Entries
    are applied in order, so a later checkpoint wins when keys overlap. Keys
    that no entry matches keep the model's current values.

    Raises:
        ValueError: for an entry without ``:`` or an unknown setting. The
            setting is validated before its checkpoint is read.
    """
    setting_to_interested_keys = {
        'diff_realbasejtsrel': [
            'real_basejtsrel_input_process', 'real_basejtsrel_sequence_pos_encoder', 'real_basejtsrel_seqTransEncoder', 'real_basejtsrel_embed_timestep', 'real_basejtsrel_sequence_pos_denoising_encoder', 'real_basejtsrel_denoising_seqTransEncoder', 'real_basejtsrel_output_process'
        ],
        'diff_basejtsrel': [
            'avg_joints_sequence_input_process', 'joints_offset_input_process', 'sequence_pos_encoder', 'seqTransEncoder', 'logvar_seqTransEncoder', 'embed_timestep', 'basejtsrel_denoising_embed_timestep', 'sequence_pos_denoising_encoder', 'basejtsrel_denoising_seqTransEncoder', 'basejtsrel_glb_denoising_latents_trans_layer', 'avg_joint_sequence_output_process', 'joint_offset_output_process', 'output_process'
        ],
        'diff_realbasejtsrel_to_joints': [
            'real_basejtsrel_to_joints_input_process', 'real_basejtsrel_to_joints_sequence_pos_encoder', 'real_basejtsrel_to_joints_seqTransEncoder', 'real_basejtsrel_to_joints_embed_timestep', 'real_basejtsrel_to_joints_sequence_pos_denoising_encoder', 'real_basejtsrel_to_joints_denoising_seqTransEncoder', 'real_basejtsrel_to_joints_output_process',
        ],
    }

    model_paths = model_path.split(";")
    print(f"Loading multiple models with split model_path: {model_paths}")
    setting_to_model_path = {}
    for cur_path in model_paths:
        if not cur_path.strip():
            continue  # tolerate empty entries such as a trailing ';'
        # Split on the first ':' only, so checkpoint paths may contain ':'.
        cur_setting_nm, sep, cur_model_path = cur_path.partition(':')
        if not sep:
            raise ValueError(f"Expected 'setting:path' in model_path entry, got {cur_path!r}")
        setting_to_model_path[cur_setting_nm.strip()] = cur_model_path.strip()

    loaded_dict = {}
    for cur_setting, cur_model_path in setting_to_model_path.items():
        interested_keys = setting_to_interested_keys.get(cur_setting)
        if interested_keys is None:
            raise ValueError(f"cur_setting:{cur_setting} Not implemented yet")
        cur_model_state_dict = torch.load(cur_model_path, map_location='cpu')
        # NOTE(review): substring matching is broad. For example,
        # 'sequence_pos_encoder' also matches 'real_basejtsrel_sequence_pos_encoder.*'
        # keys, so one setting can pull another sub-network's weights from its
        # checkpoint. Entry order therefore matters.
        for k, v in cur_model_state_dict.items():
            if any(cur_inter_key in k for cur_inter_key in interested_keys):
                loaded_dict[k] = v
    model_dict = model.state_dict()
    model_dict.update(loaded_dict)
    model.load_state_dict(model_dict)
    
            
                


def load_model_wo_clip(model, state_dict):
    """Load ``state_dict`` into ``model``, allowing only the CLIP weights to be absent.

    The state dict is loaded non-strictly, and then two checks run:

    * ``unexpected_keys`` (in ``state_dict`` but not in ``model``) must be empty;
    * ``missing_keys`` (in ``model`` but not in ``state_dict``) may only
      belong to the ``clip_model.`` submodule.

    Raises:
        AssertionError: if either check fails. It is raised explicitly, not
            via ``assert``, so the checks still run under ``python -O``, and
            the message lists the offending keys.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if unexpected_keys:
        raise AssertionError(f"Unexpected keys in state_dict: {unexpected_keys}")
    non_clip_missing = [k for k in missing_keys if not k.startswith('clip_model.')]
    if non_clip_missing:
        raise AssertionError(f"Missing non-CLIP keys in state_dict: {non_clip_missing}")

### create model and diffusion ###
def create_model_and_diffusion(args, data):
    """Instantiate the denoising network and its diffusion process.

    The network class is chosen from ``args.dataset`` and ``args.rep_type``.
    For the ``obj_base_rel_dist_we_wj_latents`` representation, a cascade of
    boolean flags on ``args`` also takes part. Any unmatched combination
    falls back to the plain ``MDM``. The constructor kwargs come from
    ``get_model_args(args, data)``, and the diffusion from
    ``create_gaussian_diffusion(args)``.

    Returns:
        ``(model, diffusion)``
    """
    model_cls = MDM
    banner = None  # optional message announcing the chosen variant

    if args.dataset in ['motion_ours']:
        rep_type = args.rep_type
        if rep_type in ["obj_base_rel_dist", "ambient_obj_base_rel_dist"]:
            model_cls = MDM_Ours
        elif rep_type == "obj_base_rel_dist_we":
            model_cls = MDM_Ours_V3
        elif rep_type == "obj_base_rel_dist_we_wj":
            model_cls = MDM_Ours_V4
        elif rep_type == "obj_base_rel_dist_we_wj_latents":
            if args.diff_spatial:
                if not args.pred_joints_offset:
                    model_cls, banner = MDM_Ours_V9, "Using MDM ours V9!!!!"
                elif args.diff_joint_quants:
                    model_cls = MDM_Ours_V13
                elif args.diff_hand_params:
                    model_cls = MDM_Ours_V14
                elif args.finetune_with_cond:
                    model_cls, banner = MDM_Ours_V12, "Using MDM ours V12!!!!"
                else:
                    model_cls, banner = MDM_Ours_V10, "Using MDM ours V10!!!!"
            elif args.diff_latents:
                model_cls, banner = MDM_Ours_V11, "Using MDM ours V11!!!!"
            elif not args.use_sep_models:
                model_cls = MDM_Ours_V5
            elif not args.use_vae:
                model_cls = MDM_Ours_V6
            elif args.pred_basejtsrel_avgjts:
                model_cls, banner = MDM_Ours_V8, "Using MDM ours V8!!!!"
            else:
                model_cls = MDM_Ours_V7

    if banner is not None:
        print(banner)
    model = model_cls(**get_model_args(args, data))
    diffusion = create_gaussian_diffusion(args)
    return model, diffusion

# Build the keyword arguments passed to the model constructors. #
def get_model_args(args, data):
    """Return the constructor kwargs shared by all MDM-style model classes.

    ``data_rep``, ``njoints`` and ``nfeats`` depend on ``args.dataset``:

    * ``humanml`` -> ('hml_vec', 263, 1)
    * ``motion_ours`` -> ('xyz', 21, 3)
    * ``kit`` -> ('hml_vec', 251, 1)
    * anything else -> the SMPL defaults ('rot6d', 25, 6)

    ``num_actions`` is taken from ``data.dataset`` when that object has such
    an attribute, and is 1 otherwise. The full ``args`` namespace is forwarded
    under the ``'args'`` key.
    """
    clip_version = 'ViT-B/32'
    action_emb = 'tensor'
    cond_mode = get_cond_mode(args)
    num_actions = getattr(data.dataset, 'num_actions', 1)

    smpl_defaults = ('rot6d', 25, 6)
    rep_by_dataset = {
        'humanml': ('hml_vec', 263, 1),
        'motion_ours': ('xyz', 21, 3),
        'kit': ('hml_vec', 251, 1),
    }
    data_rep, njoints, nfeats = rep_by_dataset.get(args.dataset, smpl_defaults)

    model_kwargs = {
        'modeltype': '',
        'njoints': njoints,
        'nfeats': nfeats,
        'num_actions': num_actions,
        'translation': True,
        'pose_rep': 'rot6d',
        'glob': True,
        'glob_rot': True,
        'latent_dim': args.latent_dim,
        'ff_size': 1024,
        'num_layers': args.layers,
        'num_heads': 4,
        'dropout': 0.1,
        'activation': "gelu",
        'data_rep': data_rep,
        'cond_mode': cond_mode,
        'cond_mask_prob': args.cond_mask_prob,
        'action_emb': action_emb,
        'arch': args.arch,
        'emb_trans_dec': args.emb_trans_dec,
        'clip_version': clip_version,
        'dataset': args.dataset,
        'args': args,
    }
    return model_kwargs

def optimize_sampled_hand_joints(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals):
    """Fit a right-hand MANO model to sampled joints with three Adam stages.

    The MANO layer (manopth ``ManoLayer``: right hand, flat hand mean, 24 PCA
    pose components, axis-angle rotations) is loaded from
    ``./data/mano_models/mano/models``, relative to the current working
    directory. All optimized variables are initialised with ``torch.randn``.
    Nothing is seeded here, so results vary between calls.

    * ``beta_var``   (bsz, 10): shape, shared by all ``ws`` frames of a sequence.
    * ``rot_var``    (bsz * ws, 3): global rotation per frame.
    * ``theta_var``  (bsz * ws, 24): pose PCA coefficients per frame.
    * ``transl_var`` (bsz * ws, 3): translation per frame.

    MANO vertices and joints are multiplied by 0.001 before any energy is
    evaluated. This is presumably a mm -> m conversion; confirm against the
    units of the inputs.

    Stages (Adam, lr 0.1 in each):

    1. 100 iterations over ``rot_var`` / ``transl_var`` only. Penetration
       weight is 0.
    2. 1000 iterations over all variables. Penetration weight is still 0:
       the nearest-base-point penetration energy is computed and printed,
       but does not contribute to the loss.
    3. 1000 iterations over all variables with penetration weight 1.0, using
       the all-base-points energy ("strategy 3" below).

    Each stage minimises
    ``0.05 * pose smoothness + 0.001 * shape prior + 0.0001 * pose prior
    + signed_dist_e_coeff * penetration + rel_e + dist_e + joint fitting``.

    Args:
        sampled_joints: (bsz, ws, nnj, 3) joints to fit. nnj presumably
            equals the MANO joint count, since the two are subtracted.
        rel_base_pts_to_joints: target offsets (joint - base point). Assumed
            broadcastable to (bsz, ws, nnj, nnb, 3); TODO confirm.
        dists_base_pts_to_joints: target signed distances, assumed
            broadcastable to (bsz, ws, nnj, nnb). Pass None to drop the term.
        base_pts: (bsz, nnb, 3) object base points.
        base_normals: normals at ``base_pts``, assumed (bsz, nnb, 3).

    Returns:
        Detached (bsz, ws, J, 3) MANO joints from the last iteration of stage 3.

    Side effects:
        Prints every loss term on every iteration. Writes the final MANO
        vertices (bsz, ws, 778, 3) to ``optimized_verts.npy`` in the current
        working directory.
    """
    # sampled_joints: bsz x ws x nnj x 3
    # energies: agreement with predicted offsets / signed distances,
    # penetration, pose smoothness and priors
    bsz, ws, nnj = sampled_joints.shape[:3]
    device = sampled_joints.device
    coarse_lr = 0.1
    num_iters = 100 # stage-1 iterations (global rotation + translation only)
    mano_path = "./data/mano_models/mano/models"

    # base points / normals repeated along the ws axis: bsz x ws x nnb x 3
    base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
    base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()

    # The 1.0 is immediately overridden: penetration is off (weight 0) in
    # stages 1 and 2 and is switched back on before stage 3.
    signed_dist_e_coeff = 1.0
    signed_dist_e_coeff = 0.0


    ### start optimization ###
    # setup MANO layer (right hand, 24 PCA pose components, axis-angle)
    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_path, # directory holding the MANO model files
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang'
    ).to(device)

    ## random init variables (unseeded) ##
    beta_var = torch.randn([bsz, 10]).to(device) # shape; one per sequence, shared over ws
    rot_var = torch.randn([bsz * ws, 3]).to(device)
    theta_var = torch.randn([bsz * ws, 24]).to(device)
    transl_var = torch.randn([bsz * ws, 3]).to(device)

    beta_var.requires_grad_()
    rot_var.requires_grad_()
    theta_var.requires_grad_()
    transl_var.requires_grad_()
    # ===== stage 1: global rotation + translation only =====
    opt = optim.Adam([rot_var, transl_var], lr=coarse_lr)
    for i_iter in range(num_iters):
        opt.zero_grad()
        # MANO forward; betas are repeated so every frame uses its sequence's shape
        hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
            beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var)
        hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x 778 x 3
        hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001

        ### === e1: match the predicted offsets and signed distances === ###
        # bsz x ws x nnj x nnb x 3 #
        rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
        # bsz x ws x nnj x nnb: offset projected onto each base point's normal
        signed_dist_base_pts_to_hand_joints = torch.sum(
            rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        )
        rel_e = torch.sum(
            (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1
        ).mean()
        if dists_base_pts_to_joints is not None:
            dist_e = torch.sum(
                (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1
            ).mean()
        else:
            dist_e = torch.zeros((1,), dtype=torch.float32).to(device).mean()


        ### === e2: penalise joints on the negative side of nearby base points === ###
        ## base_pts: bsz x nn_base_pts x 3
        ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ##
        ## bsz x ws x nnj x nnb ##

        ''' strategy 2: use all base pts, rel, dists for resolving '''
        # rel_base_pts_to_hand_joints: bsz x ws x nnj x nnb x 3 #
        # active (joint, base point) pairs: negative signed distance AND closer than 0.05
        signed_dist_mask = signed_dist_base_pts_to_hand_joints < 0.
        l2_dist_rel_joints_to_base_pts_mask = torch.sqrt(
            torch.sum(rel_base_pts_to_hand_joints ** 2, dim=-1)
        ) < 0.05
        signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_base_pts_mask.float()) > 1.5
        # Same expression as signed_dist_base_pts_to_hand_joints, so the energy
        # below is the squared signed distance averaged over the active pairs.
        dot_rel_with_normals = torch.sum(
            rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        )
        signed_dist_mask = signed_dist_mask.detach() # detach the mask #

        # dot_rel_with_normals: bsz x ws x nnj x nnb
        avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean() # computed but not used in this stage

        signed_dist_e = dot_rel_with_normals * signed_dist_base_pts_to_hand_joints
        # mean over active pairs; the clamp avoids dividing by zero when none are active
        signed_dist_e = torch.sum(
            signed_dist_e[signed_dist_mask]
        ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item()
        ###### ====== get loss for signed distances ==== ###
        ''' strategy 2: use all base pts, rel, dists for resolving '''



        ''' strategy 1: use nearest base pts, rel, dists for resolving '''
        # (disabled here; kept for reference. Stage 2 below runs this variant.)
        # dist_rhand_joints_to_base_pts = torch.sum(
        #     (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1
        # )
        # # minn_dists_idxes: bsz x ws x nnj -->  
        # minn_dists_to_base_pts, minn_dists_idxes = torch.min(
        #     dist_rhand_joints_to_base_pts, dim=-1
        # )
        # # base_pts: bsz x nn_base_pts x 3 #
        # # base_pts: bsz x ws x nn_base_pts x 3 #
        # # bsz x ws x nnj 
        
        # # object verts and object faces #
        # ## other than the sampling process; not 
        # # bsz x ws x nnj x 3 ##
        # nearest_base_pts = batched_index_select_ours(
        #     base_pts_exp, indices=minn_dists_idxes, dim=2
        # )
        # # bsz x ws x nnj x 3 # # base normalse #
        # nearest_base_normals = batched_index_select_ours(
        #     base_normals_exp, indices=minn_dists_idxes, dim=2
        # )  
        # # bsz x ws x nnj x 3 #  # the nearest distance points may be of some ambiguous 
        # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts
        # # bsz x ws x nnj #
        # signed_dist_joints_to_base_pts = torch.sum(
        #     rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1
        # )
        # # should not be negative
        # signed_dist_mask = signed_dist_joints_to_base_pts < 0.
        # l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt(
        #     torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1)
        # ) < 0.05
        # signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5
        # ### ==== mean of signed distances ==== ###
        # signed_dist_e = torch.sum( # penetration 
        #     -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask]
        # ) / torch.clamp(
        #     torch.sum(signed_dist_mask.float()), min=1e-5
        # ).item()
        ''' strategy 1: use nearest base pts, rel, dists for resolving '''

        ## === e3 smoothness and prior losses === ##
        # smoothness: MSE between the pose coefficients of consecutive frames
        pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1])
        shape_prior_loss = torch.mean(beta_var**2)
        pose_prior_loss = torch.mean(theta_var**2)
        ## === e3 smoothness and prior losses === ##

        ## === e4 hand joints should be close to sampled hand joints === ##
        dist_dec_jts_to_sampled_pts = torch.sum(
            (hand_joints - sampled_joints) ** 2, dim=-1
        ).mean()

        ### total energy; the penetration term is scaled by signed_dist_e_coeff (0 in this stage) #
        loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts

        loss.backward()
        opt.step()

        print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
        print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
        print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
        print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
        print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
        print('\trel_e Loss: {}'.format(rel_e.item()))
        print('\tdist_e Loss: {}'.format(dist_e.item()))
        print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))

    # ===== stage 2: all variables; the penetration term is still weighted by 0 =====
    fine_lr = 0.1
    num_iters = 1000
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr)
    for i_iter in range(num_iters):
        opt.zero_grad()
        # MANO forward; betas are repeated so every frame uses its sequence's shape
        hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
            beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var)
        hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x 778 x 3
        hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001

        ### === e1: match the predicted offsets and signed distances === ###
        # bsz x ws x nnj x nnb x 3 #
        rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
        # bsz x ws x nnj x nnb: offset projected onto each base point's normal
        signed_dist_base_pts_to_hand_joints = torch.sum(
            rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        )
        rel_e = torch.sum(
            (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1
        ).mean()

        # the signed-distance target is optional
        if dists_base_pts_to_joints is not None:
            dist_e = torch.sum(
                (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1
            ).mean()
        else:
            # NOTE(review): unlike stage 1, this zero is created on the CPU. It is
            # 0-dim, so it should still combine with device tensors, but adding
            # `.to(device)` would be consistent with stage 1.
            dist_e = torch.zeros((1,), dtype=torch.float32).mean()


        ### === e2: nearest-base-point penetration energy (strategy 1) === ###
        ## base_pts: bsz x nn_base_pts x 3
        ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ##
        ## squared joint-to-base-point distances: bsz x ws x nnj x nnb ##
        dist_rhand_joints_to_base_pts = torch.sum(
            (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1
        )
        # minn_dists_idxes: bsz x ws x nnj, the index of each joint's nearest base point
        minn_dists_to_base_pts, minn_dists_idxes = torch.min(
            dist_rhand_joints_to_base_pts, dim=-1
        )
        # base_pts: bsz x nn_base_pts x 3 #
        # base_pts_exp: bsz x ws x nn_base_pts x 3 #
        # bsz x ws x nnj 
        # base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
        # nearest base point per joint: bsz x ws x nnj x 3 ##
        nearest_base_pts = batched_index_select_ours(
            base_pts_exp, indices=minn_dists_idxes, dim=2
        )
        # normal of the nearest base point: bsz x ws x nnj x 3 #
        nearest_base_normals = batched_index_select_ours(
            base_normals_exp, indices=minn_dists_idxes, dim=2
        )
        # bsz x ws x nnj x 3 #
        rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts
        # signed distance along the nearest normal: bsz x ws x nnj #
        signed_dist_joints_to_base_pts = torch.sum(
            rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1
        )
        # active joints: negative signed distance AND within 0.05 of the nearest base point
        signed_dist_mask = signed_dist_joints_to_base_pts < 0.
        l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt(
            torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1)
        ) < 0.05
        signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5

        ### ==== mean of -signed distance over active joints (penetration depth) ==== ###
        signed_dist_e = torch.sum(
            -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask]
        ) / torch.clamp(
            torch.sum(signed_dist_mask.float()), min=1e-5
        ).item()

        ## === e3 smoothness and prior losses === ##
        pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1])
        shape_prior_loss = torch.mean(beta_var**2)
        pose_prior_loss = torch.mean(theta_var**2)
        ## === e3 smoothness and prior losses === ##

        ## === e4 hand joints should be close to sampled hand joints === ##
        dist_dec_jts_to_sampled_pts = torch.sum(
            (hand_joints - sampled_joints) ** 2, dim=-1
        ).mean()

        # signed_dist_e_coeff is still 0.0 here, so the penetration term is only logged
        loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts

        loss.backward()
        opt.step()

        print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
        print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
        print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
        print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
        print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
        print('\trel_e Loss: {}'.format(rel_e.item()))
        print('\tdist_e Loss: {}'.format(dist_e.item()))
        print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))


    ### ===== stage 3: refine with the penetration energy enabled (weight 1.0) ===== ##
    signed_dist_e_coeff = 1.0
    fine_lr = 0.1
    num_iters = 1000
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr)
    for i_iter in range(num_iters):
        opt.zero_grad()
        # MANO forward; betas are repeated so every frame uses its sequence's shape
        hand_verts, hand_joints = mano_layer(torch.cat([rot_var, theta_var], dim=-1),
            beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10), transl_var)
        hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001 ## bsz x ws x 778 x 3
        hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001

        ### === e1: match the predicted offsets and signed distances === ###
        # bsz x ws x nnj x nnb x 3 #
        rel_base_pts_to_hand_joints = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
        # bsz x ws x nnj x nnb: offset projected onto each base point's normal
        signed_dist_base_pts_to_hand_joints = torch.sum(
            rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        )
        rel_e = torch.sum(
            (rel_base_pts_to_hand_joints - rel_base_pts_to_joints) ** 2, dim=-1
        ).mean()

        # the signed-distance target is optional
        if dists_base_pts_to_joints is not None:
            dist_e = torch.sum(
                (signed_dist_base_pts_to_hand_joints - dists_base_pts_to_joints) ** 2, dim=-1
            ).mean()
        else:
            # NOTE(review): created on the CPU, unlike stage 1 (see stage 2 note)
            dist_e = torch.zeros((1,), dtype=torch.float32).mean()


        ''' strategy 2: use all base pts, rel, dists for resolving '''
        # rel_base_pts_to_hand_joints: bsz x ws x nnj x nnb x 3 #
        # active (joint, base point) pairs: negative signed distance AND closer than 0.05
        signed_dist_mask = signed_dist_base_pts_to_hand_joints < 0.
        l2_dist_rel_joints_to_base_pts_mask = torch.sqrt(
            torch.sum(rel_base_pts_to_hand_joints ** 2, dim=-1)
        ) < 0.05
        signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_base_pts_mask.float()) > 1.5
        ## === dot rel with normals (previous variant, disabled) === ##
        # dot_rel_with_normals = torch.sum(
        #     rel_base_pts_to_hand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
        # )
        ## === dot rel with normals === ##
        ## === strategy 3: -|joint - base point|^2 === ##
        # Active pairs have a negative signed distance, so the product below is
        # |rel|^2 * |signed distance| >= 0.
        dot_rel_with_normals = torch.sum(
            -1.0 * rel_base_pts_to_hand_joints * rel_base_pts_to_hand_joints, dim=-1
        )
        ## === dot rel with rel, strategy 3 === ##
        signed_dist_mask = signed_dist_mask.detach() # detach the mask #

        # dot_rel_with_normals: bsz x ws x nnj x nnb
        # average number of active base points per joint (logged below)
        avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean()

        signed_dist_e = dot_rel_with_normals * signed_dist_base_pts_to_hand_joints
        # mean over active pairs; the clamp avoids dividing by zero when none are active
        signed_dist_e = torch.sum(
            signed_dist_e[signed_dist_mask]
        ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item()
        ###### ====== get loss for signed distances ==== ###
        ''' strategy 2: use all base pts, rel, dists for resolving '''
        # (below: disabled nearest-base-point variant, kept for reference)
        ''' strategy 1: use nearest base pts, rel, dists for resolving '''
        # ### === e2 the signed distances to nearest points should not be negative to the neareste === ###
        # ## base_pts: bsz x nn_base_pts x 3
        # ## bsz x ws x nnj x 1 x 3 -- bsz x 1 x 1 x nnb x 3 ##
        # ## bsz x ws x nnj x nnb ##
        # dist_rhand_joints_to_base_pts = torch.sum(
        #     (hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)) ** 2, dim=-1
        # )
        # # minn_dists_idxes: bsz x ws x nnj -->  
        # minn_dists_to_base_pts, minn_dists_idxes = torch.min(
        #     dist_rhand_joints_to_base_pts, dim=-1
        # )
        # 
        # # base_pts: bsz x nn_base_pts x 3 #
        # # base_pts: bsz x ws x nn_base_pts x 3 #
        # # bsz x ws x nnj 
        # # base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
        # # bsz x ws x nnj x 3 ##
        # nearest_base_pts = batched_index_select_ours(
        #     base_pts_exp, indices=minn_dists_idxes, dim=2
        # )
        # # bsz x ws x nnj x 3 #
        # nearest_base_normals = batched_index_select_ours(
        #     base_normals_exp, indices=minn_dists_idxes, dim=2
        # )
        # # bsz x ws x nnj x 3 #
        # rel_joints_to_nearest_base_pts = hand_joints - nearest_base_pts
        # # bsz x ws x nnj #
        # signed_dist_joints_to_base_pts = torch.sum(
        #     rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1
        # )
        # # should not be negative
        # signed_dist_mask = signed_dist_joints_to_base_pts < 0.
        # # l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt(
        # #     torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1)
        # # ) < 0.05 
        # l2_dist_rel_joints_to_nearest_base_pts_mask = torch.sqrt(
        #     torch.sum(rel_joints_to_nearest_base_pts ** 2, dim=-1)
        # ) < 0.1
        # signed_dist_mask = (signed_dist_mask.float() + l2_dist_rel_joints_to_nearest_base_pts_mask.float()) > 1.5
        
        # ### ==== mean of signed distances ==== ###
        # # signed_dist_e = torch.sum(
        # #     -1.0 * signed_dist_joints_to_base_pts[signed_dist_mask]
        # # ) / torch.clamp(
        # #     torch.sum(signed_dist_mask.float()), min=1e-5
        # # ).item()
        
        # # signed_dist_joints_to_base_pts: bsz x ws x nnj # -> disstances 
        # signed_dist_joints_to_base_pts = signed_dist_joints_to_base_pts.detach()
        # # 
        
        ## penetration resolving --- strategy 
        # dot_rel_with_normals = torch.sum(
        #     rel_joints_to_nearest_base_pts * nearest_base_normals, dim=-1
        # )
        # signed_dist_mask = signed_dist_mask.detach() # detach the mask #
        # # bsz x ws x nnj --> the loss term
        
        # # dot_rel_with_normals: bsz x ws x nnj x nnb
        # avg_masks = (signed_dist_mask.float()).sum(dim=-1).mean()
        
        
        # signed_dist_e = dot_rel_with_normals * signed_dist_joints_to_base_pts
        # signed_dist_e = torch.sum(
        #     signed_dist_e[signed_dist_mask]
        # ) / torch.clamp(torch.sum(signed_dist_mask.float()), min=1e-5).item()
        # ###### ====== get loss for signed distances ==== ###
        ''' strategy 1: use nearest base pts, rel, dists for resolving '''

        ## TODO: decide whether a joint is inside the object and only project those inside it
        ## === e3 smoothness and prior losses === ##
        pose_smoothness_loss = F.mse_loss(theta_var.view(bsz, ws, -1)[:, 1:], theta_var.view(bsz, ws, -1)[:, :-1])
        shape_prior_loss = torch.mean(beta_var**2)
        pose_prior_loss = torch.mean(theta_var**2)
        ## === e3 smoothness and prior losses === ##

        ## === e4 hand joints should be close to sampled hand joints === ##
        dist_dec_jts_to_sampled_pts = torch.sum(
            (hand_joints - sampled_joints) ** 2, dim=-1
        ).mean()

        # TODO: explore other penetration-resolving schemes, e.g. projecting each
        # penetrating vertex along a weighted sum of base-point directions, or a
        # learned / stochastic attraction field.


        loss = pose_smoothness_loss * 0.05 + shape_prior_loss*0.001 + pose_prior_loss * 0.0001 + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts

        loss.backward()
        opt.step()

        print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
        print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
        print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
        print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
        print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
        print('\trel_e Loss: {}'.format(rel_e.item()))
        print('\tdist_e Loss: {}'.format(dist_e.item()))
        print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
        # avg_masks
        print('\tAvg masks: {}'.format(avg_masks.item()))



    ''' returning sampled_joints ''' 
    # joints from the last stage-3 iteration replace the sampled ones
    sampled_joints = hand_joints
    # side effect: final vertices are written to the current working directory
    np.save("optimized_verts.npy", hand_verts.detach().cpu().numpy())
    print(f"Optimized verts saved to optimized_verts.npy")
    return sampled_joints.detach()



def get_obj_trimesh_list(obj_verts, obj_faces):
    """Build one ``trimesh.Trimesh`` per object.

    ``obj_verts[i]`` and ``obj_faces[i]`` describe object ``i``. Torch tensors
    are detached and converted to CPU numpy arrays; any other value is passed
    to ``trimesh.Trimesh`` unchanged. Meshes are built with ``process=False``
    (no trimesh cleanup of the given vertices/faces) and ``use_embree=True``.

    Returns a list with ``len(obj_verts)`` meshes.
    """
    def _to_numpy(arr):
        return arr.detach().cpu().numpy() if isinstance(arr, torch.Tensor) else arr

    meshes = []
    for idx in range(len(obj_verts)):
        verts, faces = obj_verts[idx], obj_faces[idx]
        verts = _to_numpy(verts)
        faces = _to_numpy(faces)
        meshes.append(
            trimesh.Trimesh(vertices=verts, faces=faces, process=False, use_embree=True)
        )
    return meshes

def judge_penetrated_points(obj_mesh, subj_pts):
    """Label which query points lie inside each object mesh.

    Args:
        obj_mesh: sequence of meshes, one per batch element. Each must expose
            ``contains(points)`` taking an (N, 3) array and returning N
            booleans, e.g. the ``trimesh.Trimesh`` objects built by
            ``get_obj_trimesh_list``.
        subj_pts: tensor of shape ``(bsz, ..., 3)``. ``subj_pts[i]`` is tested
            against ``obj_mesh[i]``, and any number of axes may sit between
            the batch axis and the final coordinate axis.

    Returns:
        float32 tensor of shape ``(len(obj_mesh), *subj_pts.shape[1:-1])`` on
        ``subj_pts.device``, with 1.0 for points inside the mesh and 0.0
        otherwise. An empty ``obj_mesh`` yields an empty tensor.
    """
    nn_bsz = len(obj_mesh)
    if nn_bsz == 0:
        # np.stack would raise on an empty list
        return torch.zeros((0, *subj_pts.shape[1:-1]), dtype=torch.float32, device=subj_pts.device)

    tot_pts_inside_objmesh_labels = []
    for i_bsz in range(nn_bsz):
        cur_subj_pts = subj_pts[i_bsz].detach().cpu().numpy()
        ori_subj_pts_shape = cur_subj_pts.shape
        # Flatten every leading axis into one (N, 3) query array, whatever the rank.
        pts_inside_objmesh = obj_mesh[i_bsz].contains(cur_subj_pts.reshape(-1, 3))
        pts_inside_objmesh = np.asarray(pts_inside_objmesh, dtype=np.float32)
        # restore the per-batch point layout (drop the coordinate axis)
        pts_inside_objmesh = pts_inside_objmesh.reshape(ori_subj_pts_shape[:-1])
        tot_pts_inside_objmesh_labels.append(pts_inside_objmesh)

    tot_pts_inside_objmesh_labels = np.stack(tot_pts_inside_objmesh_labels, axis=0)
    tot_pts_inside_objmesh_labels = torch.from_numpy(tot_pts_inside_objmesh_labels).float()
    return tot_pts_inside_objmesh_labels.to(subj_pts.device)

# TODO: try other optimization strategies, e.g. sequential optimization. #
def optimize_sampled_hand_joints_wobj(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces):
    """Fit a MANO hand sequence to sampled joints, then pull penetrating joints out of the object.

    Randomly initialised MANO parameters are optimised with Adam in three
    stages that share one objective:
    - per-sequence shape ``beta``;
    - per-frame global rotation;
    - 24 PCA pose coefficients;
    - per-frame translation.

    The stages are:

    * coarse (100 iters, lr 0.1): only rotation and translation are updated;
    * fine (1000 iters, lr 0.1): all parameters are updated;
    * refine (100 iters, lr 0.1): all parameters are updated and the
      penetration term gets weight 1.0. Joints that ``judge_penetrated_points``
      reports as inside the object mesh are pulled toward their nearest base
      point.

    In the coarse and fine stages a penetration term is still computed and
    printed, but its weight is 0.0, so it does not influence the optimisation.

    The shared objective combines:
    - e1: match the target offsets/signed distances to the base points;
    - e3: pose-smoothness, shape and pose priors;
    - e4: stay close to ``sampled_joints``.

    Args:
        sampled_joints: target joints, ``bsz x ws x nnj x 3``.
        rel_base_pts_to_joints: target joint-minus-base-point offsets; must
            broadcast against ``bsz x ws x nnj x nnb x 3``.
        dists_base_pts_to_joints: target signed distances of joints along the
            base normals (broadcastable to ``bsz x ws x nnj x nnb``), or
            ``None`` to drop that term.
        base_pts: ``bsz x nnb x 3`` base points.
        base_normals: ``bsz x nnb x 3`` normals at ``base_pts``.
        obj_verts: per-batch object vertices (tensors or arrays).
        obj_normals: unused; kept for interface compatibility.
        obj_faces: per-batch object faces (tensors or arrays).

    Returns:
        Detached ``bsz x ws x nnj x 3`` tensor of MANO joints evaluated at the
        final optimised parameters (i.e. after the last optimizer step).

    Side effects:
        - Prints losses every iteration.
        - During the refine stage, writes a debug dict to
          ``./data/mdm/tmp_saving/optim_iter_{i}.npy`` every iteration.
        - Writes the final hand vertices to ``optimized_verts.npy``.
    """
    # Built first so malformed object inputs fail before any optimisation runs.
    tot_obj_trimeshes = get_obj_trimesh_list(obj_verts, obj_faces)

    bsz, ws, _ = sampled_joints.shape[:3]
    device = sampled_joints.device
    mano_path = "./data/mano_models/mano/models"
    coarse_lr, coarse_iters = 0.1, 100
    fine_lr, fine_iters = 0.1, 1000
    refine_lr, refine_iters = 0.1, 100
    sv_dict_folder = "./data/mdm/tmp_saving"

    # bsz x ws x nnb x 3: base points/normals repeated per frame so the nearest
    # one can be gathered per joint with batched_index_select_ours(dim=2).
    base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
    base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()

    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_path,
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang',
    ).to(device)

    # Random initialisation; the draw order (beta, rot, theta, transl) is kept
    # so seeded runs reproduce the same starting point.
    beta_var = torch.randn([bsz, 10]).to(device)
    rot_var = torch.randn([bsz * ws, 3]).to(device)
    theta_var = torch.randn([bsz * ws, 24]).to(device)
    transl_var = torch.randn([bsz * ws, 3]).to(device)
    for var in (beta_var, rot_var, theta_var, transl_var):
        var.requires_grad_()

    def mano_forward():
        return _wobj_mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws)

    def run_stage(opt, num_iters, signed_dist_e_coeff, penetration_energy):
        """Run ``num_iters`` steps of the shared objective plus a stage-specific penetration term.

        ``penetration_energy(i_iter, hand_joints, rel_to_base, signed_dist_to_base)``
        returns ``(signed_dist_e, avg_masks)``; ``avg_masks`` may be ``None``.
        """
        for i_iter in range(num_iters):
            opt.zero_grad()
            _, hand_joints = mano_forward()

            ## === e1: offsets / signed distances to base points === ##
            rel_e, dist_e, rel_to_base, signed_dist_to_base = _wobj_base_pts_energies(
                hand_joints, base_pts, base_normals, rel_base_pts_to_joints, dists_base_pts_to_joints
            )
            ## === e2: penetration (stage specific) === ##
            signed_dist_e, avg_masks = penetration_energy(i_iter, hand_joints, rel_to_base, signed_dist_to_base)
            ## === e3: smoothness and priors === ##
            pose_smoothness_loss, shape_prior_loss, pose_prior_loss = _wobj_prior_energies(theta_var, beta_var, bsz, ws)
            ## === e4: stay close to the sampled joints === ##
            dist_dec_jts_to_sampled_pts = torch.sum(
                (hand_joints - sampled_joints) ** 2, dim=-1
            ).mean()

            loss = (
                pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
                + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts
            )
            loss.backward()
            opt.step()

            print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
            print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
            print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
            print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
            print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
            print('\trel_e Loss: {}'.format(rel_e.item()))
            print('\tdist_e Loss: {}'.format(dist_e.item()))
            print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
            if avg_masks is not None:
                print('\tAvg masks: {}'.format(avg_masks.item()))

    def all_base_pts_penetration(i_iter, hand_joints, rel_to_base, signed_dist_to_base):
        # Joint/base-point pairs on the negative side of the base normal and
        # within 0.05 (presumably meters) of the base point.
        mask = (signed_dist_to_base < 0.) & (torch.sqrt(torch.sum(rel_to_base ** 2, dim=-1)) < 0.05)
        return _wobj_masked_mean(signed_dist_to_base * signed_dist_to_base, mask), None

    def nearest_base_pt_penetration(i_iter, hand_joints, rel_to_base, signed_dist_to_base):
        nearest_base_pts, nearest_base_normals = _wobj_nearest_base_pts(rel_to_base, base_pts_exp, base_normals_exp)
        rel_near = hand_joints - nearest_base_pts  # bsz x ws x nnj x 3
        signed_near = torch.sum(rel_near * nearest_base_normals, dim=-1)  # bsz x ws x nnj
        mask = (signed_near < 0.) & (torch.sqrt(torch.sum(rel_near ** 2, dim=-1)) < 0.05)
        return _wobj_masked_mean(-1.0 * signed_near, mask), None

    # Loop-invariant parts of the refine-stage debug dump, converted once.
    obj_verts_np = [_wobj_to_numpy(cur_verts) for cur_verts in obj_verts]
    obj_faces_np = [_wobj_to_numpy(cur_faces) for cur_faces in obj_faces]
    base_pts_np = base_pts.detach().cpu().numpy()
    base_normals_np = base_normals.detach().cpu().numpy()

    def inside_mesh_penetration(i_iter, hand_joints, rel_to_base, signed_dist_to_base):
        # bsz x ws x nnj: joints the object meshes report as inside.
        inside_mask = judge_penetrated_points(tot_obj_trimeshes, hand_joints).bool()
        nearest_base_pts, nearest_base_normals = _wobj_nearest_base_pts(rel_to_base, base_pts_exp, base_normals_exp)
        rel_near = hand_joints - nearest_base_pts
        # Squared distance to the nearest base point, averaged over the inside
        # joints: pulls each penetrating joint onto that base point.
        signed_dist_e = _wobj_masked_mean(torch.sum(rel_near * rel_near, dim=-1), inside_mask)
        avg_masks = inside_mask.float().sum(dim=-1).mean()

        sv_dict = {
            'pts_inside_objmesh_labels_mask': inside_mask.detach().cpu().numpy(),
            'hand_joints': hand_joints.detach().cpu().numpy(),
            'obj_verts': obj_verts_np,
            'obj_faces': obj_faces_np,
            'base_pts': base_pts_np,
            'base_normals': base_normals_np,
            'nearest_base_pts': nearest_base_pts.detach().cpu().numpy(),
            'nearest_base_normals': nearest_base_normals.detach().cpu().numpy(),
        }
        sv_dict_fn = os.path.join(sv_dict_folder, f"optim_iter_{i_iter}.npy")
        np.save(sv_dict_fn, sv_dict)
        print(f"Obj and subj saved to {sv_dict_fn}")
        return signed_dist_e, avg_masks

    run_stage(optim.Adam([rot_var, transl_var], lr=coarse_lr), coarse_iters, 0.0, all_base_pts_penetration)
    run_stage(optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=fine_lr), fine_iters, 0.0, nearest_base_pt_penetration)
    os.makedirs(sv_dict_folder, exist_ok=True)
    run_stage(optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=refine_lr), refine_iters, 1.0, inside_mesh_penetration)

    # Evaluate MANO once more so the output reflects the final optimizer step
    # (the in-loop forward pass always precedes that iteration's opt.step()).
    with torch.no_grad():
        hand_verts, hand_joints = mano_forward()
    np.save("optimized_verts.npy", hand_verts.detach().cpu().numpy())
    print("Optimized verts saved to optimized_verts.npy")
    return hand_joints.detach()


def _wobj_to_numpy(value):
    """Return ``value`` as a NumPy array, detaching/moving it to CPU if it is a tensor."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _wobj_mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws):
    """Run MANO on all ``bsz * ws`` frames; return (verts ``bsz x ws x 778 x 3``, joints ``bsz x ws x J x 3``)."""
    hand_verts, hand_joints = mano_layer(
        torch.cat([rot_var, theta_var], dim=-1),
        beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10),  # one shape per sequence, shared by its frames
        transl_var,
    )
    # The 0.001 factor rescales the layer output (presumably mm -> m, to match base_pts).
    hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001
    hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001
    return hand_verts, hand_joints


def _wobj_base_pts_energies(hand_joints, base_pts, base_normals, rel_base_pts_to_joints, dists_base_pts_to_joints):
    """e1 energies against the base points.

    Returns ``(rel_e, dist_e, rel_to_base, signed_dist_to_base)`` where
    ``rel_to_base`` is ``bsz x ws x nnj x nnb x 3`` (joint minus base point) and
    ``signed_dist_to_base`` its projection on the base normals (``bsz x ws x nnj x nnb``).
    ``dist_e`` is a zero scalar on the joints' device when no target distances are given.
    """
    rel_to_base = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
    signed_dist_to_base = torch.sum(rel_to_base * base_normals.unsqueeze(1).unsqueeze(1), dim=-1)
    rel_e = torch.sum((rel_to_base - rel_base_pts_to_joints) ** 2, dim=-1).mean()
    if dists_base_pts_to_joints is not None:
        dist_e = torch.sum((signed_dist_to_base - dists_base_pts_to_joints) ** 2, dim=-1).mean()
    else:
        dist_e = torch.zeros((1,), dtype=torch.float32, device=hand_joints.device).mean()
    return rel_e, dist_e, rel_to_base, signed_dist_to_base


def _wobj_nearest_base_pts(rel_to_base, base_pts_exp, base_normals_exp):
    """Gather, per joint, the closest base point and its normal (each ``bsz x ws x nnj x 3``)."""
    # bsz x ws x nnj: index of the closest base point (first one on ties).
    nearest_idxes = torch.min(torch.sum(rel_to_base ** 2, dim=-1), dim=-1)[1]
    nearest_base_pts = batched_index_select_ours(base_pts_exp, indices=nearest_idxes, dim=2)
    nearest_base_normals = batched_index_select_ours(base_normals_exp, indices=nearest_idxes, dim=2)
    return nearest_base_pts, nearest_base_normals


def _wobj_masked_mean(values, mask):
    """Mean of ``values[mask]``; 0 when the mask is empty (denominator clamped to 1e-5)."""
    return torch.sum(values[mask]) / torch.clamp(torch.sum(mask.float()), min=1e-5).item()


def _wobj_prior_energies(theta_var, beta_var, bsz, ws):
    """e3 energies: (pose smoothness across frames, shape prior, pose prior)."""
    theta_seq = theta_var.view(bsz, ws, -1)
    pose_smoothness_loss = F.mse_loss(theta_seq[:, 1:], theta_seq[:, :-1])
    shape_prior_loss = torch.mean(beta_var ** 2)
    pose_prior_loss = torch.mean(theta_var ** 2)
    return pose_smoothness_loss, shape_prior_loss, pose_prior_loss


# TODO: other optimization strategies? e.g. sequential optimization. #
def optimize_sampled_hand_joints_wobj_v2(sampled_joints, rel_base_pts_to_joints, dists_base_pts_to_joints, base_pts, base_normals, obj_verts, obj_normals, obj_faces):
    """Fit a MANO hand sequence to sampled joints, then resolve joint penetrations into the object.

    MANO parameters are randomly initialised and optimised with Adam (lr 0.1):
    a shape vector ``beta`` shared by every frame of a sequence, plus a per-frame
    global rotation, 24 PCA pose coefficients and a translation. Three stages
    share the same data and prior terms:

    1. global rotation / translation only, 100 iterations, penetration term weight 0;
    2. all parameters, 1000 iterations, penetration term weight 0;
    3. all parameters, 100 iterations, penetration term weight 1.0.
       Flagged joints are pulled towards their (possibly carried-over) nearest
       base points. Frame 0 is flagged by ``judge_penetrated_points``; later
       frames are flagged by crossing the previous frame's base-point plane.

    Args:
        sampled_joints: target joints, compared element-wise with the MANO joints
            reshaped to ``bsz x ws x nnj x 3``.
        rel_base_pts_to_joints: target joint offsets from every base point. It is
            compared with ``hand_joints[..., None, :] - base_pts`` (presumably
            ``bsz x ws x nnj x nnb x 3`` -- TODO confirm).
        dists_base_pts_to_joints: optional target signed distances along the base
            normals. When ``None``, the ``dist_e`` term is zero.
        base_pts, base_normals: base points / normals, broadcast over windows via
            ``unsqueeze(1)`` (presumably ``bsz x nnb x 3``).
        obj_verts, obj_faces: per-sequence object meshes, iterated element-wise
            and passed to ``get_obj_trimesh_list``.
        obj_normals: unused.

    Returns:
        The detached MANO joints of the last refinement iteration. They are
        evaluated before that iteration's optimizer step.

    Side effects:
        Prints per-iteration losses. On every refinement iteration it writes
        ``./data/mdm/tmp_saving/optim_iter_{i}.npy``; at the end it writes
        ``optimized_verts.npy`` to the working directory.
    """
    tot_obj_trimeshes = get_obj_trimesh_list(obj_verts, obj_faces)

    bsz, ws = sampled_joints.shape[:2]
    device = sampled_joints.device
    mano_path = "./data/mano_models/mano/models"
    lr = 0.1

    # Per-window copies of the base points / normals, gathered by batched_index_select_ours.
    base_pts_exp = base_pts.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()
    base_normals_exp = base_normals.unsqueeze(1).repeat(1, ws, 1, 1).contiguous()

    mano_layer = ManoLayer(
        flat_hand_mean=True,
        side='right',
        mano_root=mano_path,
        ncomps=24,
        use_pca=True,
        root_rot_mode='axisang',
        joint_rot_mode='axisang'
    ).to(device)

    ## random init variables (draw order: beta, rot, theta, transl) ##
    beta_var = torch.randn([bsz, 10]).to(device)
    rot_var = torch.randn([bsz * ws, 3]).to(device)
    theta_var = torch.randn([bsz * ws, 24]).to(device)
    transl_var = torch.randn([bsz * ws, 3]).to(device)
    for var in (beta_var, rot_var, theta_var, transl_var):
        var.requires_grad_()

    # ==== stages 1-2: data + prior terms; the penetration term is computed/logged but weighted by 0 ====
    signed_dist_e_coeff = 0.0
    fit_stages = (
        # (optimised variables, iterations, penetration term uses the nearest base point only)
        ([rot_var, transl_var], 100, False),
        ([rot_var, transl_var, beta_var, theta_var], 1000, True),
    )
    for params, num_iters, use_nearest_base in fit_stages:
        opt = optim.Adam(params, lr=lr)
        for i_iter in range(num_iters):
            opt.zero_grad()
            hand_verts, hand_joints = _wobj_v2_mano_forward(
                mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws)
            rel_to_base, signed_to_base, rel_e, dist_e = _wobj_v2_base_pt_terms(
                hand_joints, base_pts, base_normals, rel_base_pts_to_joints, dists_base_pts_to_joints)
            if use_nearest_base:
                signed_dist_e = _wobj_v2_nearest_penetration_e(
                    hand_joints, rel_to_base, base_pts_exp, base_normals_exp)
            else:
                signed_dist_e = _wobj_v2_all_base_penetration_e(rel_to_base, signed_to_base)
            priors = _wobj_v2_prior_losses(theta_var, beta_var, bsz, ws)
            ## e4: hand joints should be close to the sampled hand joints ##
            dist_dec_jts_to_sampled_pts = torch.sum(
                (hand_joints - sampled_joints) ** 2, dim=-1
            ).mean()

            loss = _wobj_v2_total_loss(
                priors, signed_dist_e, signed_dist_e_coeff, rel_e, dist_e, dist_dec_jts_to_sampled_pts)
            loss.backward()
            opt.step()

            _wobj_v2_print_losses(
                i_iter, loss, priors, signed_dist_e, rel_e, dist_e, dist_dec_jts_to_sampled_pts)

    # ==== stage 3: refinement with the penetration term switched on ====
    signed_dist_e_coeff = 1.0
    num_iters = 100
    sv_dict_folder = "./data/mdm/tmp_saving"
    opt = optim.Adam([rot_var, transl_var, beta_var, theta_var], lr=lr)
    for i_iter in range(num_iters):
        opt.zero_grad()
        hand_verts, hand_joints = _wobj_v2_mano_forward(
            mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws)
        rel_to_base, _, rel_e, dist_e = _wobj_v2_base_pt_terms(
            hand_joints, base_pts, base_normals, rel_base_pts_to_joints, dists_base_pts_to_joints)

        # Per-joint labels of points inside the object meshes.
        inside_mask = judge_penetrated_points(tot_obj_trimeshes, hand_joints).bool()

        nearest_base_pts, nearest_base_normals = _wobj_v2_nearest_base(
            rel_to_base, base_pts_exp, base_normals_exp)
        # NOTE(review): only frame 0 of the returned mask comes from inside_mask. Later frames
        # are flagged solely by plane crossing, so judge_penetrated_points is ignored there.
        # Confirm this is intended.
        penetration_mask, nearest_base_pts, nearest_base_normals = _wobj_v2_propagate_base_pts(
            hand_joints, nearest_base_pts, nearest_base_normals, inside_mask)
        penetration_mask = penetration_mask.detach()

        # Penetration energy: mean squared distance of flagged joints to their base points.
        rel_to_nearest = hand_joints - nearest_base_pts
        signed_dist_e = _wobj_v2_masked_mean(
            torch.sum(rel_to_nearest * rel_to_nearest, dim=-1), penetration_mask)
        avg_masks = penetration_mask.float().sum(dim=-1).mean()

        priors = _wobj_v2_prior_losses(theta_var, beta_var, bsz, ws)

        #### ==== debug dump of the current state ==== ####
        sv_dict = {
            'pts_inside_objmesh_labels_mask': penetration_mask.detach().cpu().numpy(),
            'hand_joints': hand_joints.detach().cpu().numpy(),

            'obj_verts': [cur_verts.detach().cpu().numpy() for cur_verts in obj_verts],
            'obj_faces': [cur_faces.detach().cpu().numpy() for cur_faces in obj_faces],

            'base_pts': base_pts.detach().cpu().numpy(),
            'base_normals': base_normals.detach().cpu().numpy(),
            'nearest_base_pts': nearest_base_pts.detach().cpu().numpy(),
            'nearest_base_normals': nearest_base_normals.detach().cpu().numpy(),
        }
        os.makedirs(sv_dict_folder, exist_ok=True)
        sv_dict_fn = os.path.join(sv_dict_folder, f"optim_iter_{i_iter}.npy")
        np.save(sv_dict_fn, sv_dict)
        print(f"Obj and subj saved to {sv_dict_fn}")

        ## e4: hand joints should be close to the sampled hand joints ##
        dist_dec_jts_to_sampled_pts = torch.sum(
            (hand_joints - sampled_joints) ** 2, dim=-1
        ).mean()

        loss = _wobj_v2_total_loss(
            priors, signed_dist_e, signed_dist_e_coeff, rel_e, dist_e, dist_dec_jts_to_sampled_pts)
        loss.backward()
        opt.step()

        _wobj_v2_print_losses(
            i_iter, loss, priors, signed_dist_e, rel_e, dist_e, dist_dec_jts_to_sampled_pts,
            avg_masks=avg_masks)

    np.save("optimized_verts.npy", hand_verts.detach().cpu().numpy())
    print("Optimized verts saved to optimized_verts.npy")
    return hand_joints.detach()


# ---- private helpers of optimize_sampled_hand_joints_wobj_v2 ----

def _wobj_v2_mano_forward(mano_layer, rot_var, theta_var, beta_var, transl_var, bsz, ws):
    """Run MANO with the shape shared over each sequence; return verts/joints * 0.001 as bsz x ws x . x 3."""
    pose_coeffs = torch.cat([rot_var, theta_var], dim=-1)
    betas = beta_var.unsqueeze(1).repeat(1, ws, 1).view(-1, 10)
    hand_verts, hand_joints = mano_layer(pose_coeffs, betas, transl_var)
    hand_verts = hand_verts.view(bsz, ws, 778, 3) * 0.001
    hand_joints = hand_joints.view(bsz, ws, -1, 3) * 0.001
    return hand_verts, hand_joints


def _wobj_v2_base_pt_terms(hand_joints, base_pts, base_normals, rel_base_pts_to_joints, dists_base_pts_to_joints):
    """Return joint-to-base offsets, their signed distances along the base normals, and the rel/dist energies."""
    # joint minus base point for every (joint, base point) pair (bsz x ws x nnj x nnb x 3 if base_pts is bsz x nnb x 3)
    rel_to_base = hand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1)
    signed_to_base = torch.sum(
        rel_to_base * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
    )
    rel_e = torch.sum(
        (rel_to_base - rel_base_pts_to_joints) ** 2, dim=-1
    ).mean()
    if dists_base_pts_to_joints is not None:
        dist_e = torch.sum(
            (signed_to_base - dists_base_pts_to_joints) ** 2, dim=-1
        ).mean()
    else:
        # Zero placeholder created on the joints' device in every stage.
        dist_e = torch.zeros((1,), dtype=torch.float32, device=hand_joints.device).mean()
    return rel_to_base, signed_to_base, rel_e, dist_e


def _wobj_v2_masked_mean(values, mask):
    """Average ``values`` over the entries selected by boolean ``mask``; 0 when the mask is empty."""
    return torch.sum(values[mask]) / torch.clamp(torch.sum(mask.float()), min=1e-5).item()


def _wobj_v2_all_base_penetration_e(rel_to_base, signed_to_base):
    """Mean squared signed distance over (joint, base point) pairs behind the base plane and within 0.05."""
    close = torch.sqrt(torch.sum(rel_to_base ** 2, dim=-1)) < 0.05
    mask = ((signed_to_base < 0.) & close).detach()
    return _wobj_v2_masked_mean(signed_to_base * signed_to_base, mask)


def _wobj_v2_nearest_base(rel_to_base, base_pts_exp, base_normals_exp):
    """Gather each joint's closest base point and its normal (per the original notes, bsz x ws x nnj x 3)."""
    sq_dists = torch.sum(rel_to_base ** 2, dim=-1)
    _, nearest_idxes = torch.min(sq_dists, dim=-1)
    nearest_base_pts = batched_index_select_ours(base_pts_exp, indices=nearest_idxes, dim=2)
    nearest_base_normals = batched_index_select_ours(base_normals_exp, indices=nearest_idxes, dim=2)
    return nearest_base_pts, nearest_base_normals


def _wobj_v2_nearest_penetration_e(hand_joints, rel_to_base, base_pts_exp, base_normals_exp):
    """Mean negative signed distance of joints behind (and within 0.05 of) their nearest base point."""
    nearest_base_pts, nearest_base_normals = _wobj_v2_nearest_base(
        rel_to_base, base_pts_exp, base_normals_exp)
    rel_to_nearest = hand_joints - nearest_base_pts
    signed_dist = torch.sum(rel_to_nearest * nearest_base_normals, dim=-1)
    close = torch.sqrt(torch.sum(rel_to_nearest ** 2, dim=-1)) < 0.05
    mask = (signed_dist < 0.) & close
    return _wobj_v2_masked_mean(-1.0 * signed_dist, mask)


def _wobj_v2_propagate_base_pts(hand_joints, nearest_base_pts, nearest_base_normals, inside_mask):
    """Carry base points forward in time for joints that cross the previous frame's base plane.

    Frame 0 keeps its nearest base points and uses ``inside_mask`` as its mask. For
    each later frame, a joint is flagged when its previous-frame signed distance was
    >= 0 and it now has a negative signed distance to the previous frame's base
    point / normal. A flagged joint keeps that previous base point and normal.

    Returns:
        ``(masks, base_pts, base_normals)``, each stacked over batch and frames.
    """
    tot_masks, tot_base_pts, tot_base_normals = [], [], []
    n_frames = nearest_base_pts.size(1)
    for i_bsz in range(nearest_base_pts.size(0)):
        seq_jts = hand_joints[i_bsz]
        masks = [inside_mask[i_bsz][0]]
        seq_base_pts = [nearest_base_pts[i_bsz][0]]
        seq_base_normals = [nearest_base_normals[i_bsz][0]]
        signed_dists = [torch.sum((seq_jts[0] - seq_base_pts[0]) * seq_base_normals[0], dim=-1)]
        for i_fr in range(1, n_frames):
            cur_jts = seq_jts[i_fr]
            dist_to_prev_plane = torch.sum(
                (cur_jts - seq_base_pts[-1]) * seq_base_normals[-1], dim=-1
            )
            crossed = (signed_dists[-1] >= 0.) & (dist_to_prev_plane < 0.)
            cur_base_pts = nearest_base_pts[i_bsz][i_fr].clone()
            cur_base_pts[crossed] = seq_base_pts[-1][crossed]
            cur_base_normals = nearest_base_normals[i_bsz][i_fr].clone()
            cur_base_normals[crossed] = seq_base_normals[-1][crossed]
            cur_signed_dist = torch.sum((cur_jts - cur_base_pts) * cur_base_normals, dim=-1)
            # Crossed joints get distance 0 so they still count as ">= 0" for the next frame's test.
            cur_signed_dist[crossed] = 0.
            masks.append(crossed)
            seq_base_pts.append(cur_base_pts)
            seq_base_normals.append(cur_base_normals)
            # Record this frame's distance so the next frame is tested against it
            # (previously frame 0's distance was reused for every frame).
            signed_dists.append(cur_signed_dist)
        tot_masks.append(torch.stack(masks, dim=0))
        tot_base_pts.append(torch.stack(seq_base_pts, dim=0))
        tot_base_normals.append(torch.stack(seq_base_normals, dim=0))
    return (
        torch.stack(tot_masks, dim=0),
        torch.stack(tot_base_pts, dim=0),
        torch.stack(tot_base_normals, dim=0),
    )


def _wobj_v2_prior_losses(theta_var, beta_var, bsz, ws):
    """Return (pose smoothness over frames, shape prior, pose prior) losses."""
    theta_seq = theta_var.view(bsz, ws, -1)
    pose_smoothness_loss = F.mse_loss(theta_seq[:, 1:], theta_seq[:, :-1])
    shape_prior_loss = torch.mean(beta_var**2)
    pose_prior_loss = torch.mean(theta_var**2)
    return pose_smoothness_loss, shape_prior_loss, pose_prior_loss


def _wobj_v2_total_loss(priors, signed_dist_e, signed_dist_e_coeff, rel_e, dist_e, dist_dec_jts_to_sampled_pts):
    """Weighted sum of all terms (prior weights 0.05 / 0.001 / 0.0001; data terms weight 1)."""
    pose_smoothness_loss, shape_prior_loss, pose_prior_loss = priors
    return (pose_smoothness_loss * 0.05 + shape_prior_loss * 0.001 + pose_prior_loss * 0.0001
            + signed_dist_e * signed_dist_e_coeff + rel_e + dist_e + dist_dec_jts_to_sampled_pts)


def _wobj_v2_print_losses(i_iter, loss, priors, signed_dist_e, rel_e, dist_e, dist_dec_jts_to_sampled_pts, avg_masks=None):
    """Print the per-iteration loss breakdown (plus the average mask count when given)."""
    pose_smoothness_loss, shape_prior_loss, pose_prior_loss = priors
    print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
    print('\tShape Prior Loss: {}'.format(shape_prior_loss.item()))
    print('\tPose Prior Loss: {}'.format(pose_prior_loss.item()))
    print('\tPose Smoothness Loss: {}'.format(pose_smoothness_loss.item()))
    print('\tsigned_dist_e Loss: {}'.format(signed_dist_e.item()))
    print('\trel_e Loss: {}'.format(rel_e.item()))
    print('\tdist_e Loss: {}'.format(dist_e.item()))
    print('\tdist_dec_jts_to_sampled_pts Loss: {}'.format(dist_dec_jts_to_sampled_pts.item()))
    if avg_masks is not None:
        print('\tAvg masks: {}'.format(avg_masks.item()))

    

## 
def create_gaussian_diffusion(args):
    """Instantiate the spaced Gaussian diffusion process configured by ``args``.

    Fixed settings:
        * 1000 steps, with no respacing, beta scaling or timestep rescaling.
        * The model predicts x0 (``START_X``).
        * Sigmas are fixed: ``FIXED_SMALL`` when ``args.sigma_small``, else ``FIXED_LARGE``.
        * The loss is MSE.

    The concrete SpacedDiffusion class depends on ``args.dataset`` and
    ``args.rep_type``. For the latent representation it also depends on the
    ``diff_*`` flags. The ``lambda_*`` weights, ``denoising_stra``,
    ``inter_optim`` and ``args`` itself are forwarded to the constructor.
    """
    steps = 1000
    scale_beta = 1.  # no scaling
    predict_xstart = True  # we always predict x_start (a.k.a. x0)
    learn_sigma = False
    rescale_timesteps = False
    # No respacing: keep every one of the `steps` timesteps.
    timestep_respacing = [steps]

    betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)

    print(f"dataset: {args.dataset}, rep_type: {args.rep_type}")
    on_ours = args.dataset in ['motion_ours']
    rep_type = args.rep_type
    if on_ours and rep_type in ("obj_base_rel_dist", "ambient_obj_base_rel_dist"):
        print(f"here! dataset: {args.dataset}, rep_type: {args.rep_type}")
        diffusion_cls = SpacedDiffusion_Ours
    elif on_ours and rep_type == "obj_base_rel_dist_we":
        diffusion_cls = SpacedDiffusion_OursV2
    elif on_ours and rep_type == "obj_base_rel_dist_we_wj":
        diffusion_cls = SpacedDiffusion_OursV3
    elif on_ours and rep_type == "obj_base_rel_dist_we_wj_latents":
        # First flag that is set wins; flags after it are never read.
        latent_choices = (
            ('diff_joint_quants', SpacedDiffusion_OursV7),
            ('diff_hand_params', SpacedDiffusion_OursV9),
            ('diff_spatial', SpacedDiffusion_OursV5),
            ('diff_latents', SpacedDiffusion_OursV6),
        )
        diffusion_cls = next(
            (cls for flag, cls in latent_choices if getattr(args, flag)),
            SpacedDiffusion_OursV4,
        )
    else:
        diffusion_cls = SpacedDiffusion

    if learn_sigma:
        model_var_type = gd.ModelVarType.LEARNED_RANGE
    elif args.sigma_small:
        model_var_type = gd.ModelVarType.FIXED_SMALL
    else:
        model_var_type = gd.ModelVarType.FIXED_LARGE

    model_mean_type = gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON

    return diffusion_cls(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=model_mean_type,
        model_var_type=model_var_type,
        loss_type=gd.LossType.MSE,
        rescale_timesteps=rescale_timesteps,
        lambda_vel=args.lambda_vel,
        lambda_rcxyz=args.lambda_rcxyz,
        lambda_fc=args.lambda_fc,
        denoising_stra=args.denoising_stra,
        inter_optim=args.inter_optim,
        args=args,
    )
    
### From decoded energies to optimized joints ###
## Latent variables: energies encoded from the energies computed on perturbed inputs. ##
## The decoded energies should also match the clean energy term. ##


## All of these values should be denormalized. ##
def optimize_joints_according_to_e(dec_joints, base_pts, base_normals, dec_e, nn_iters=10, coarse_lr=0.001):
    """Refine decoded joints so their frame-to-frame displacement energies match ``dec_e``.

    For each pair of consecutive frames, the joint displacement is split into two parts:
    the signed component along every base normal, and the remaining perpendicular part.
    Both are weighted by an attraction factor ``exp(-||joint - base_pt||)`` computed on
    the earlier frame. Adam then minimises the MSE between these energies and the
    decoded targets.

    Args:
        dec_joints: decoded joints, presumably ``bsz x ws x nnj x 3`` (TODO confirm).
            The caller's tensor is not modified.
        base_pts: base points, broadcast as ``bsz x 1 x 1 x nnb x 3`` (presumably
            ``bsz x nnb x 3``).
        base_normals: normals of the base points, same layout as ``base_pts``.
        dec_e: dict with ``'dec_e_along_normals'`` and ``'dec_e_vt_normals'``, each
            presumably ``bsz x (ws - 1) x nnj x nnb``.
        nn_iters: number of Adam iterations (default 10).
        coarse_lr: Adam learning rate (default 0.001).

    Returns:
        The optimised joints, detached, with the same shape as ``dec_joints``.
    """
    dec_e_along_normals = dec_e['dec_e_along_normals']
    dec_e_vt_normals = dec_e['dec_e_vt_normals']

    # Optimise a private leaf copy. The caller's tensor is neither flagged for grad nor
    # updated in place, and non-leaf inputs no longer trip Adam's leaf-tensor check.
    dec_joints = dec_joints.detach().clone().requires_grad_()
    opt = optim.Adam([dec_joints], lr=coarse_lr)

    k_f = 1.  # decay rate of the attraction factor
    k_a = 1.  # weight of the along-normal energy
    k_b = 1.  # weight of the perpendicular energy
    # Loop-invariant broadcasts: bsz x 1 x 1 x nnb x 3.
    base_pts_b = base_pts.unsqueeze(1).unsqueeze(1)
    normals_b = base_normals.unsqueeze(1).unsqueeze(1)

    for i_iter in range(nn_iters):
        # bsz x ws x nnj x nnb x 3: joint minus base point.
        rel_base_pts_to_joints = dec_joints.unsqueeze(-2) - base_pts_b
        l2_rel_base_pts_to_joints = torch.norm(rel_base_pts_to_joints, dim=-1)
        # Attraction of every frame except the last: bsz x (ws - 1) x nnj x nnb.
        att_forces = torch.exp(-k_f * l2_rel_base_pts_to_joints)[:, :-1]
        # Frame-to-frame joint displacements: bsz x (ws - 1) x nnj x 3.
        joints_disp = dec_joints[:, 1:] - dec_joints[:, :-1]

        # Signed displacement along each base normal: bsz x (ws - 1) x nnj x nnb.
        disp_along_normals = torch.sum(normals_b * joints_disp.unsqueeze(-2), dim=-1)
        # Displacement minus its along-normal component.
        disp_vt_normals = joints_disp.unsqueeze(-2) - disp_along_normals.unsqueeze(-1) * normals_b
        # torch.norm has a zero subgradient at 0. sqrt(sum(x**2)) would back-propagate NaN
        # whenever a joint's perpendicular displacement is exactly zero (e.g. a static joint).
        dist_vt_normals = torch.norm(disp_vt_normals, dim=-1)

        e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(disp_along_normals)
        e_disp_rel_to_base_vt_normals = k_b * att_forces * dist_vt_normals

        loss_cur_e_pred_e_along_normals = ((e_disp_rel_to_base_along_normals - dec_e_along_normals) ** 2).mean()
        loss_cur_e_pred_e_vt_normals = ((e_disp_rel_to_base_vt_normals - dec_e_vt_normals) ** 2).mean()
        loss = loss_cur_e_pred_e_along_normals + loss_cur_e_pred_e_vt_normals

        opt.zero_grad()
        loss.backward()
        opt.step()

        print('Iter {}: {}'.format(i_iter, loss.item()), flush=True)
        print('\tloss_cur_e_pred_e_along_normals: {}'.format(loss_cur_e_pred_e_along_normals.item()))
        print('\tloss_cur_e_pred_e_vt_normals: {}'.format(loss_cur_e_pred_e_vt_normals.item()))
    return dec_joints.detach()