# This code is based on https://github.com/openai/guided-diffusion
"""
This code started out as a PyTorch port of Ho et al's diffusion models:
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py

Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
"""

import enum
import math

import numpy as np
import torch
import torch as th
from copy import deepcopy
from torch import optim, nn
from diffusion.nn import mean_flat, sum_flat
from diffusion.losses import normal_kl, discretized_gaussian_log_likelihood
from data_loaders.humanml.scripts import motion_process
import utils.model_util as model_util
# # # obj_verts, obj_faces
import utils.model_utils as model_utils
import utils.common_utils as common_utils
from manopth.manolayer import ManoLayer
from sample.reconstruct_data import calculate_disp_quants_batched, calculate_disp_quants_batched_v2


def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.):
    """
    Return the pre-defined beta schedule registered under ``schedule_name``.

    The schedules are designed to stay similar as ``num_diffusion_timesteps``
    changes. Schedules may be added, but existing ones should not be removed
    or altered once committed, to keep results reproducible.

    :param schedule_name: ``"linear"`` or ``"cosine"``.
    :param num_diffusion_timesteps: how many betas to produce.
    :param scale_betas: extra multiplier for the linear schedule only; the
        cosine schedule ignores it.
    :return: 1-D numpy array of betas.
    :raises NotImplementedError: for any other schedule name.
    """
    if schedule_name == "cosine":
        return betas_for_alpha_bar(
            num_diffusion_timesteps,
            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
        )
    if schedule_name != "linear":
        raise NotImplementedError(f"unknown beta schedule: {schedule_name}")

    # Ho et al.'s linear schedule (1e-4 -> 0.02 at 1000 steps), rescaled so
    # that it remains comparable for other step counts.
    step_scale = scale_betas * 1000 / num_diffusion_timesteps
    first_beta = step_scale * 0.0001
    last_beta = step_scale * 0.02
    return np.linspace(first_beta, last_beta, num_diffusion_timesteps, dtype=np.float64)

## betas for alpha bar ##
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretise a continuous alpha-bar curve into per-step betas.

    ``alpha_bar`` maps t in [0, 1] to the cumulative product of (1 - beta) up
    to that point of the diffusion process. Step i gets
    ``beta_i = 1 - alpha_bar((i + 1) / N) / alpha_bar(i / N)``, capped at
    ``max_beta``.

    :param num_diffusion_timesteps: number of betas N to produce.
    :param alpha_bar: callable taking t in [0, 1] and returning alpha-bar(t).
    :param max_beta: upper bound on each beta; values below 1 avoid
        singularities near t = 1.
    :return: 1-D numpy array of length N.
    """
    n_steps = num_diffusion_timesteps
    return np.array([
        min(1 - alpha_bar((step + 1) / n_steps) / alpha_bar(step / n_steps), max_beta)
        for step in range(n_steps)
    ])


class ModelMeanType(enum.Enum):
    """
    The quantity the denoising model is trained to output.
    """

    PREVIOUS_X = 1  # the model predicts x_{t-1}
    START_X = 2  # the model predicts x_0
    EPSILON = 3  # the model predicts the added noise epsilon


class ModelVarType(enum.Enum):
    """
    How the model's output variance is obtained.

    LEARNED_RANGE lets the model predict a value between FIXED_SMALL and
    FIXED_LARGE, which is an easier target than an unconstrained variance.
    """

    LEARNED = 1
    FIXED_SMALL = 2
    FIXED_LARGE = 3
    LEARNED_RANGE = 4


class LossType(enum.Enum):
    """Training objective used for the diffusion model."""

    MSE = 1  # raw MSE loss (plus KL when learning variances)
    RESCALED_MSE = 2  # raw MSE loss (plus RESCALED_KL when learning variances)
    KL = 3  # the variational lower bound
    RESCALED_KL = 4  # like KL, but rescaled to estimate the full VLB

    def is_vb(self):
        """Return True when this loss is a variational-bound (KL) objective."""
        return self in (LossType.KL, LossType.RESCALED_KL)

class VarianceSchedule(torch.nn.Module):
    """
    Noise-schedule lookup tables stored as module buffers.

    From a 1-D tensor of betas this precomputes:

    * ``betas``         - a copy of the input,
    * ``alphas``        - ``1 - betas``,
    * ``alpha_bars``    - running product of ``alphas`` (accumulated in log
                          space),
    * ``sigmas_flex``   - ``sqrt(betas)``,
    * ``sigmas_inflex`` - ``sqrt((1 - alpha_bars[i-1]) / (1 - alpha_bars[i]) * betas[i])``
                          for ``i >= 1``, and 0 at index 0.

    ``get_sigmas`` blends the two sigma tables; ``uniform_sample_t`` draws
    random step indices.
    """

    def __init__(self, num_steps, betas):
        """
        :param num_steps: step count used by ``uniform_sample_t``.
        :param betas: 1-D torch tensor of betas. It is cloned, so the caller's
            tensor is not modified.
        """
        super().__init__()
        self.num_steps = num_steps

        print(f"betas: {betas.size()}, betas_0: {betas[0]}, betas_T: {betas[-1]}")
        beta_table = betas.clone()
        alpha_table = 1 - beta_table

        # Prefix-sum of log(alpha) == log of the cumulative product of alpha.
        log_alpha_bar = torch.log(alpha_table)
        for step in range(1, log_alpha_bar.size(0)):
            log_alpha_bar[step] = log_alpha_bar[step - 1] + log_alpha_bar[step]
        alpha_bar = log_alpha_bar.exp()

        sigma_flex = torch.sqrt(beta_table)
        # Posterior-variance form. Index 0 has no predecessor and stays 0.
        inflex_var = torch.zeros_like(sigma_flex)
        inflex_var[1:] = ((1 - alpha_bar[:-1]) / (1 - alpha_bar[1:])) * beta_table[1:]
        sigma_inflex = torch.sqrt(inflex_var)

        for buf_name, buf in (
            ('betas', beta_table),
            ('alphas', alpha_table),
            ('alpha_bars', alpha_bar),
            ('sigmas_flex', sigma_flex),
            ('sigmas_inflex', sigma_inflex),
        ):
            self.register_buffer(buf_name, buf)

    def uniform_sample_t(self, batch_size):
        """
        Draw ``batch_size`` step indices uniformly, with replacement, from
        ``1..num_steps`` (inclusive). Returns a list of Python ints.
        """
        candidates = np.arange(1, self.num_steps + 1)
        return np.random.choice(candidates, batch_size).tolist()

    def get_sigmas(self, t, flexibility):
        """
        Linearly interpolate between the two sigma tables at step(s) ``t``.

        ``flexibility == 1`` yields ``sigmas_flex[t]``; ``flexibility == 0``
        yields ``sigmas_inflex[t]``.
        """
        assert 0 <= flexibility <= 1
        flex_part = self.sigmas_flex[t] * flexibility
        inflex_part = self.sigmas_inflex[t] * (1 - flexibility)
        return flex_part + inflex_part


### for temporaldiff ###
class GaussianDiffusionV4:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ##
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
        lambda_rcxyz=0.,
        lambda_vel=0.,
        lambda_pose=1.,
        lambda_orient=1.,
        lambda_loc=1.,
        data_rep='rot6d',
        lambda_root_vel=0.,
        lambda_vel_rcxyz=0.,
        lambda_fc=0.,
        denoising_stra="rep",
        inter_optim=False,
        args=None,
    ):
        """
        Store the configuration and precompute the diffusion schedule tables.

        :param betas: 1-D array-like of betas, one per diffusion step, each in
            (0, 1]. Converted to float64.
        :param model_mean_type: a ModelMeanType.
        :param model_var_type: a ModelVarType.
        :param loss_type: a LossType. Must be ``LossType.MSE`` whenever any of
            ``lambda_rcxyz``, ``lambda_vel``, ``lambda_root_vel``,
            ``lambda_vel_rcxyz`` or ``lambda_fc`` is > 0.
        :param rescale_timesteps: if True, ``_scale_timesteps`` maps timesteps
            onto the 0-1000 range before they reach the model.
        :param data_rep: data representation name. ``lambda_pose`` may only
            differ from 1 when this is ``'rot_vel'``.
        :param denoising_stra: denoising strategy name, stored as-is.
        :param inter_optim: stored as-is.
        :param args: experiment config namespace. Only the ``diff_jts``,
            ``diff_basejtsrel`` and ``diff_basejtse`` flags are read here. When
            ``args`` is None, all three default to False. Other methods (e.g.
            ``p_sample_loop_progressive``) read further attributes from it.
        :raises ValueError: if ``lambda_pose != 1`` and ``data_rep != 'rot_vel'``.
        """
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps
        self.data_rep = data_rep

        self.args = args  # may be None

        # Which latent streams are diffused. ``args`` defaults to None, so
        # dereferencing it unconditionally would raise AttributeError for the
        # default. Fall back to "no stream diffused" in that case.
        if args is None:
            self.diff_jts = False
            self.diff_basejtsrel = False
            self.diff_basejtse = False
        else:
            self.diff_jts = args.diff_jts
            self.diff_basejtsrel = args.diff_basejtsrel
            self.diff_basejtse = args.diff_basejtse

        if data_rep != 'rot_vel' and lambda_pose != 1.:
            raise ValueError('lambda_pose is relevant only when training on velocities!')
        self.lambda_pose = lambda_pose
        self.lambda_orient = lambda_orient
        self.lambda_loc = lambda_loc

        self.lambda_rcxyz = lambda_rcxyz
        self.lambda_vel = lambda_vel
        self.lambda_root_vel = lambda_root_vel
        self.lambda_vel_rcxyz = lambda_vel_rcxyz
        self.lambda_fc = lambda_fc

        # Denoising strategy for the reverse process.
        self.denoising_stra = denoising_stra
        self.inter_optim = inter_optim

        if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \
                self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0.:
            assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!'

        # Use float64 for accuracy.
        betas = np.array(betas, dtype=np.float64)
        self.betas = betas
        assert len(betas.shape) == 1, "betas must be 1-D"
        assert (betas > 0).all() and (betas <= 1).all()

        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
        self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = (
            betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        # log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = (
            betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev)
            * np.sqrt(alphas)
            / (1.0 - self.alphas_cumprod)
        )

        # Element-wise (unreduced) squared error so masks can be applied later.
        self.l2_loss = lambda a, b: (a - b) ** 2
        
        

    def masked_l2(self, a, b, mask):
        """
        Mean squared error between ``a`` and ``b`` over unmasked entries.

        The source comments assume ``a``/``b`` are (bs, J, Jdim, seqlen) and
        ``mask`` is (bs, 1, 1, seqlen), so each unmasked mask entry covers
        ``a.shape[1] * a.shape[2]`` elements.

        :return: per-sample loss, i.e. the masked squared-error sum divided
            by the number of covered elements (reduced via ``sum_flat``).
        """
        masked_sq_err = self.l2_loss(a, b) * mask.float()
        total_err = sum_flat(masked_sq_err)
        entries_per_mask_cell = a.shape[1] * a.shape[2]
        n_covered = sum_flat(mask) * entries_per_mask_cell
        return total_err / n_covered


    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: a tuple (mean, variance, log_variance), each broadcast to
                 x_start's shape.
        """
        target_shape = x_start.shape
        signal_coef = _extract_into_tensor(self.sqrt_alphas_cumprod, t, target_shape)
        mean = signal_coef * x_start
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, target_shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, target_shape)
        return mean, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data forward by the given number of steps, i.e. draw a
        sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: optional pre-drawn standard normal noise with
            x_start's shape. Fresh noise is drawn when omitted.
        :return: a noisy version of x_start.
        """
        noise = th.randn_like(x_start) if noise is None else noise
        assert noise.shape == x_start.shape
        target_shape = x_start.shape
        signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, target_shape)
        noise_scale = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, target_shape)
        return signal_scale * x_start + noise_scale * noise

    ## q_sample, 
    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior

            q(x_{t-1} | x_t, x_0)

        :param x_start: x_0 (or a prediction of it); same shape as ``x_t``.
        :param x_t: the noisy sample at step ``t``.
        :param t: timesteps used to index the precomputed posterior tables.
        :return: (posterior_mean, posterior_variance,
                 posterior_log_variance_clipped), broadcast to x_t's shape.
        """
        assert x_start.shape == x_t.shape
        target_shape = x_t.shape
        coef_x0 = _extract_into_tensor(self.posterior_mean_coef1, t, target_shape)
        coef_xt = _extract_into_tensor(self.posterior_mean_coef2, t, target_shape)
        mean = coef_x0 * x_start + coef_xt * x_t
        variance = _extract_into_tensor(self.posterior_variance, t, target_shape)
        log_variance = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, target_shape
        )
        assert mean.shape[0] == variance.shape[0] == log_variance.shape[0] == x_start.shape[0]
        return mean, variance, log_variance

    def p_mean_variance(
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Run the denoiser once and build the reverse-step statistics for the
        ``base_jts_e_feats`` latent stream.

        Unlike the upstream guided-diffusion version, ``x`` is a dict of model
        inputs. It must contain ``x['input_data']['base_pts']`` and, when
        ``self.diff_basejtse`` is set, ``x['base_jts_e_feats']`` holding x_t.
        The model's ``base_jts_e_feats`` output is treated as the x_0
        prediction and passed to ``q_posterior_mean_variance``. The variance is
        always the fixed posterior variance, whatever ``self.model_var_type``
        says.

        :param model: wrapper exposing ``model.model.dec_latents_to_joints_with_t``.
        :param x: dict of current (noisy) inputs at step ``t``.
        :param t: 1-D tensor of timesteps with one entry per batch element.
        :param clip_denoised: accepted for API compatibility; unused.
        :param denoised_fn: accepted for API compatibility; unused.
        :param model_kwargs: accepted for API compatibility; unused.
        :return: an empty dict when ``self.diff_basejtse`` is False. Otherwise a
                 dict with ``basejtse_seq_latents_mean``,
                 ``basejtse_seq_latents_variance``,
                 ``basejtse_seq_latents_log_variance`` and the decoder outputs
                 ``dec_e_along_normals``, ``dec_e_vt_normals``, ``dec_d`` and
                 ``rel_vel_dec``, passed through unchanged.
        """
        if model_kwargs is None:
            model_kwargs = {}

        input_data = x['input_data']
        batch_size = input_data['base_pts'].shape[0]
        assert t.shape == (batch_size,)

        out_dict = model.model.dec_latents_to_joints_with_t(
            x, input_data, self._scale_timesteps(t).clone()
        )

        if not self.diff_basejtse:
            return {}

        # Swap axes 0 and 1 before indexing by ``t`` (one entry per batch
        # element), then swap back. Presumably axis 1 of the stored latents is
        # the batch axis -- confirm against the model.
        x_t_swapped = x['base_jts_e_feats'].permute(1, 0, 2)
        pred_x0_swapped = out_dict['base_jts_e_feats'].permute(1, 0, 2)
        mean_swapped, _, _ = self.q_posterior_mean_variance(
            x_start=pred_x0_swapped, x_t=x_t_swapped, t=t
        )
        swapped_shape = mean_swapped.shape
        variance = _extract_into_tensor(
            self.posterior_variance, t, swapped_shape
        ).permute(1, 0, 2)
        log_variance = _extract_into_tensor(
            self.posterior_log_variance_clipped, t, swapped_shape
        ).permute(1, 0, 2)

        return {
            "basejtse_seq_latents_mean": mean_swapped.permute(1, 0, 2),
            "basejtse_seq_latents_variance": variance,
            "basejtse_seq_latents_log_variance": log_variance,
            "dec_e_along_normals": out_dict["dec_e_along_normals"],
            "dec_e_vt_normals": out_dict["dec_e_vt_normals"],
            "dec_d": out_dict["dec_d"],
            "rel_vel_dec": out_dict["rel_vel_dec"],
        }

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """Recover x_0 from x_t and predicted noise: sqrt(1/a_bar) * x_t - sqrt(1/a_bar - 1) * eps."""
        assert x_t.shape == eps.shape
        target_shape = x_t.shape
        recip = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, target_shape)
        recip_m1 = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, target_shape)
        return recip * x_t - recip_m1 * eps

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        """Invert the posterior mean: x_0 = (x_{t-1} - coef2 * x_t) / coef1."""
        assert x_t.shape == xprev.shape
        target_shape = x_t.shape
        inv_coef1 = _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, target_shape)
        coef_ratio = _extract_into_tensor(
            self.posterior_mean_coef2 / self.posterior_mean_coef1, t, target_shape
        )
        return inv_coef1 * xprev - coef_ratio * x_t

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """Recover the noise implied by x_t and an x_0 prediction."""
        target_shape = x_t.shape
        numerator = (
            _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, target_shape) * x_t
            - pred_xstart
        )
        denominator = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, target_shape)
        return numerator / denominator

    def _scale_timesteps(self, t):
        """
        Map integer timesteps onto the 0-1000 range when
        ``self.rescale_timesteps`` is set; otherwise return ``t`` unchanged.
        """
        if not self.rescale_timesteps:
            return t
        return t.float() * (1000.0 / self.num_timesteps)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

        :param cond_fn: callable ``cond_fn(x, scaled_t, **model_kwargs)``
            returning the gradient.
        :param p_mean_var: dict with at least ``"mean"`` and ``"variance"``.
        :param x: the current sample.
        :param t: timesteps; rescaled with ``_scale_timesteps`` before being
            passed to ``cond_fn``.
        :param model_kwargs: extra keyword arguments for ``cond_fn``. None is
            treated as no extra arguments.
        :return: the shifted mean ``mean + variance * gradient``.
        """
        # ``**None`` raises TypeError, so normalise the documented default.
        if model_kwargs is None:
            model_kwargs = {}
        gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Like ``condition_mean``, but ``cond_fn`` receives the raw timesteps
        and the full ``p_mean_var`` dict: ``cond_fn(x, t, p_mean_var, **model_kwargs)``.

        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

        :param model_kwargs: extra keyword arguments for ``cond_fn``. None is
            treated as no extra arguments.
        :return: the shifted mean ``mean + variance * gradient``.
        """
        # ``**None`` raises TypeError, so normalise the documented default.
        if model_kwargs is None:
            model_kwargs = {}
        gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020).

        :param p_mean_var: dict with at least ``"pred_xstart"``. It is
            shallow-copied, not modified.
        :param model_kwargs: extra keyword arguments for ``cond_fn``. None is
            treated as no extra arguments.
        :return: a copy of ``p_mean_var`` with ``"pred_xstart"`` and
                 ``"mean"`` replaced by their conditioned values.
        """
        # ``**None`` raises TypeError, so normalise the documented default.
        if model_kwargs is None:
            model_kwargs = {}
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, self._scale_timesteps(t), **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Like ``condition_score``, but ``cond_fn`` receives the raw timesteps
        and the full ``p_mean_var`` dict: ``cond_fn(x, t, p_mean_var, **model_kwargs)``.

        Uses the conditioning strategy from Song et al (2020).

        :param model_kwargs: extra keyword arguments for ``cond_fn``. None is
            treated as no extra arguments.
        :return: a copy of ``p_mean_var`` with ``"pred_xstart"`` and
                 ``"mean"`` replaced by their conditioned values.
        """
        # ``**None`` raises TypeError, so normalise the documented default.
        if model_kwargs is None:
            model_kwargs = {}
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, t, p_mean_var, **model_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        const_noise=False,
    ):
        """
        Sample the ``base_jts_e_feats`` latents at step t-1 from the model.

        :param model: the model to sample from (forwarded to ``p_mean_variance``).
        :param x: dict of current inputs at step ``t``. It must hold
            ``x['base_jts_e_feats']`` (x_t) when ``self.diff_basejtse`` is set.
        :param t: 1-D tensor of timesteps, one per batch element. Entries equal
            to 0 receive no noise.
        :param clip_denoised: forwarded to ``p_mean_variance``.
        :param denoised_fn: forwarded to ``p_mean_variance``.
        :param cond_fn: accepted for API compatibility; not used here.
        :param model_kwargs: forwarded to ``p_mean_variance``.
        :param const_noise: if True, every batch element receives the same
            noise draw.
        :return: an empty dict when ``self.diff_basejtse`` is False. Otherwise a
                 dict with ``basejtse_seq_latents_sample`` plus the decoder
                 outputs ``dec_e_along_normals``, ``dec_e_vt_normals``,
                 ``dec_d`` and ``rel_vel_dec``, passed through from
                 ``p_mean_variance``.
        """
        out = self.p_mean_variance(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

        if not self.diff_basejtse:
            return {}

        x_t = x['base_jts_e_feats']
        noise = th.randn_like(x_t)
        if const_noise:
            # The mask built from ``t`` (one entry per batch element) is
            # broadcast against axis 0 of the axis-swapped tensors below, so
            # axis 1 of ``x_t`` is the batch axis. Reuse batch element 0's noise
            # for every element and keep the 3-D layout that
            # ``permute(1, 0, 2)`` requires.
            noise = noise[:, [0]].repeat(1, x_t.shape[1], 1)
        # No noise when t == 0.
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))

        # Combine in the axis-swapped (batch-first) layout, then swap back.
        mean_swapped = out["basejtse_seq_latents_mean"].permute(1, 0, 2)
        log_var_swapped = out["basejtse_seq_latents_log_variance"].permute(1, 0, 2)
        sample_swapped = (
            mean_swapped
            + nonzero_mask * th.exp(0.5 * log_var_swapped) * noise.permute(1, 0, 2)
        )

        return {
            "basejtse_seq_latents_sample": sample_swapped.permute(1, 0, 2),
            "dec_e_along_normals": out["dec_e_along_normals"],
            "dec_e_vt_normals": out["dec_e_vt_normals"],
            "dec_d": out['dec_d'],
            "rel_vel_dec": out['rel_vel_dec'],
        }

    def p_sample_with_grad(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} with autograd enabled, so ``cond_fn`` can differentiate
        through the model output.

        :param model: forwarded to ``p_mean_variance``.
        :param x: the current sample at step ``t``. It is detached and marked
            ``requires_grad``.
        :param t: timesteps. Entries equal to 0 receive no noise.
        :param clip_denoised: forwarded to ``p_mean_variance``.
        :param denoised_fn: forwarded to ``p_mean_variance``.
        :param cond_fn: optional gradient function, applied through
            ``condition_mean_with_grad``.
        :param model_kwargs: forwarded to ``p_mean_variance`` and ``cond_fn``.
        :return: dict with ``"sample"`` and a detached ``"pred_xstart"``.

        NOTE(review): this is the upstream tensor-based implementation. It
        calls ``x.detach()`` and reads ``"mean"``, ``"log_variance"`` and
        ``"pred_xstart"`` from ``p_mean_variance``. This class's
        ``p_mean_variance`` indexes ``x`` as a dict and returns only
        ``basejtse_*`` / decoder keys, so this path looks unusable as-is --
        confirm whether it is still called.
        """
        with th.enable_grad():
            x = x.detach().requires_grad_()
            out = self.p_mean_variance(
                model,
                x,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            noise = th.randn_like(x)
            # Zero out the noise term where t == 0.
            mask_shape = (-1,) + (1,) * (len(x.shape) - 1)
            nonzero_mask = (t != 0).float().view(*mask_shape)
            if cond_fn is not None:
                out["mean"] = self.condition_mean_with_grad(
                    cond_fn, out, x, t, model_kwargs=model_kwargs
                )
        noise_scale = nonzero_mask * th.exp(0.5 * out["log_variance"])
        sample = out["mean"] + noise_scale * noise
        return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()}

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        dump_steps=None,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Generate samples from the model by draining ``p_sample_loop_progressive``.

        All arguments except ``dump_steps`` are forwarded unchanged to
        ``p_sample_loop_progressive``.

        :param model: the model module.
        :param shape: the shape of the samples.
        :param noise: optional starting noise.
        :param device: target device. When None, the progressive sampler
            chooses one.
        :param progress: forwarded progress flag.
        :param dump_steps: optional collection of step indices (0-based, in
            iteration order) whose per-step outputs should be kept.
        :param const_noise: if True, all samples share the same noise
            throughout sampling.
        :param st_timestep: optional starting timestep, forwarded as-is.
        :return: a list of deep copies of the per-step outputs at
                 ``dump_steps`` when ``dump_steps`` is given; otherwise the
                 output of the last step (None if no step ran).
        """
        collect_steps = dump_steps is not None
        dumped = []
        final = None

        progressive = self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            const_noise=const_noise,
            st_timestep=st_timestep,
        )
        for step_idx, step_out in enumerate(progressive):
            if collect_steps and step_idx in dump_steps:
                # Deep copy so later steps cannot mutate the stored snapshot.
                dumped.append(deepcopy(step_out))
            final = step_out

        return dumped if collect_steps else final

    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        const_noise=False,
        st_timestep=None,
    ): # 
        """
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        ####### ==== a conditional ssampling from init_images here!!! ==== #######
        ## === give joints shape here === ##
        ### ==== set the shape for sampling ==== ###
        ### === init image should not be none === ###
        base_pts = init_image['base_pts']
        base_normals = init_image['base_normals'] 
        rhand_joints = init_image['rhand_joints']
        vel_obj_pts_to_hand_pts = init_image['vel_obj_pts_to_hand_pts']
        obj_pts_disp = init_image['obj_pts_disp']
        
        avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ##
        
        if 'sampled_base_pts_nearest_obj_pc' in init_image:
            ambient_init_image = {
                'sampled_base_pts_nearest_obj_pc': init_image['sampled_base_pts_nearest_obj_pc'],
                'sampled_base_pts_nearest_obj_vns': init_image['sampled_base_pts_nearest_obj_vns'],
            }
            
        # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals
        # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals
        if self.args.wo_e_normalization:
            init_image['per_frame_avg_disp_along_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_along_normals'])
            init_image['per_frame_avg_disp_vt_normals'] = torch.zeros_like(init_image['per_frame_avg_disp_vt_normals'])
            init_image['per_frame_std_disp_along_normals'] = torch.ones_like(init_image['per_frame_std_disp_along_normals'])
            init_image['per_frame_std_disp_vt_normals'] = torch.ones_like(init_image['per_frame_std_disp_vt_normals'])
            
        if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel
            # init_image['per_frame_avg_joints_rel'] = torch.zeros_like(init_image['per_frame_avg_joints_rel'])
            init_image['per_frame_std_joints_rel'] = torch.ones_like(init_image['per_frame_std_joints_rel'])

            
        init_image_avg_std_stats = {
            'rhand_joints': init_image['rhand_joints'],
            'per_frame_avg_joints_rel': init_image['per_frame_avg_joints_rel'],
            'per_frame_std_joints_rel': init_image['per_frame_std_joints_rel'],
            'per_frame_avg_joints_dists_rel': init_image['per_frame_avg_joints_dists_rel'],
            'per_frame_std_joints_dists_rel': init_image['per_frame_std_joints_dists_rel'],
        }
        
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list)) 
        
        # if noise is not None:
        #     img = noise
        # else:
        #     img = th.randn(*shape, device=device)

        ### sample progresssive ###
        # if skip_timesteps and init_image is None:
        #     rhand_joints = th.zeros_like(img)

        # indicies
        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
        if st_timestep is not None: ## indices 
            indices = indices[-st_timestep: ]
            print(f"st_timestep: {st_timestep}, indices: {indices}")


        
        joints_scaling_factor = 5.
        
        # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ##
        # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ##
        rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ##
        init_image['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True)
        # x_start['per_frame_avg_joints_rel'] = torch
        # bsz x ws x nnj x nnb x 3 #
        rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - init_image['per_frame_avg_joints_rel']) / init_image['per_frame_std_joints_rel']
        
        if self.denoising_stra == "rep":
            ''' Normalization Strategy 4 '''
            my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] # t con
            
            # normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device)
            # noise_rhand_joints = th.randn_like(normed_rhand_joints)
            # pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, my_t, noise=noise_rhand_joints)
            
            # 
            # ### scale rhand joints ##
            # # rhand joints: bsz x ws x nnj x 3
            exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3)
            # ## avg exp rhadn joints ##
            
            if self.args.jts_sclae_stra == "std":
                avg_exp_rhand_joints = torch.mean(exp_rhand_joints, dim=1, keepdim=True)
                extents_rhand_joints = torch.std(exp_rhand_joints, dim=1, keepdim=True)
            elif self.args.jts_sclae_stra == "bbox":
                maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True)
                minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True)
                avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2. # avg_exp_rhand_joints
                extents_rhand_joints = maxx_exp_rhand_joints - minn_exp_rhand_joints ### bsz x 1 x 3 # 
                extents_rhand_joints = torch.sqrt(torch.sum(extents_rhand_joints ** 2, dim=-1, keepdim=True))
            else:
                raise ValueError(f"Unrecognized jts_sclae_stra: {self.args.jts_sclae_stra}")
            
            rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1) ) / extents_rhand_joints.unsqueeze(1)
            
            scaled_rhand_joints = rhand_joints * joints_scaling_factor
            noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints)
            pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, my_t, noise=noise_scaled_rhand_joints)
            
            # pert_rhand_joints: bsz x nnj x 3 ## -> 
            # pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            
            
            ### Calculate moving related energies ###
            # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here #
            # pert rhand joints # ## denormed base pts to rhand joints -> denormed rel positions ##
            denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * init_image['per_frame_std_joints_rel'] + init_image['per_frame_avg_joints_rel']
            denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1)
            denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum(
                denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            ) ## l2 real base pts 
            k_f = 1. ## l2 rel base pts to pert rhand joints ##
            # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb #
            l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1)
            ### att_forces ##
            att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb #
            # bsz x (ws - 1) x nnj x nnb #
            att_forces = att_forces[:, :-1, :, :] # attraction forces -1 #
            # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ##
            # bsz x (ws - 1) x nnj x 3 --> displacements s#
            denormed_rhand_joints_disp = denormed_rhand_joints[:, 1:, :, :] - denormed_rhand_joints[:, :-1, :, :]
            # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # 
            # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb #
            signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum(
                base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1
            )
            # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals #
            rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2)  - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum(
                rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1
            ))
            k_a = 1.
            k_b = 1.
            ### 
            e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal)
            # (ws - 1) x nnj x nnb # -> dist vt normals # ## 
            e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal
            # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ##
            # 
            # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
            # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
            # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals",  x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size())
            e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals']
            e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals']

        else:
            raise ValueError(f"Unrecognized denoising stra: {self.denoising_stra}")
        
        # denoised es #
        # prersentations --- denoisng--> 
        input_data = {
            'base_pts': base_pts,
            'base_normals': base_normals,
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints, 
            # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
            # 'pert_rhand_joints': pert_normed_rhand_joints,
            'pert_rhand_joints': pert_scaled_rhand_joints,
            'rhand_joints': rhand_joints, # 
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
            'avg_joints_sequence': avg_joints_sequence,
        }
        
        if 'sampled_base_pts_nearest_obj_pc' in init_image:
            input_data.update(ambient_init_image)
            
        input_data.update(
            {
                'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, 
                'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
                'vel_obj_pts_to_hand_pts': vel_obj_pts_to_hand_pts,
                'obj_pts_disp': obj_pts_disp
            }
        )
        # input 
        input_data.update(init_image_avg_std_stats)
        input_data['rhand_joints'] = rhand_joints # normed 
        
        
        my_t = th.tensor([indices[-1]] * shape[0], device=device)
        # clean_joint_seq_latents = model(input_data, self._scale_timesteps(my_t))
        # noise_joint_seq_latents = th.randn_like(clean_joint_seq_latents)
        # # pert_joint_seq_latents: bsz x seq x d # 
        # pert_joint_seq_latents = self.q_sample(clean_joint_seq_latents.permute(1, 0, 2).contiguous(), my_t, noise=noise_joint_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
        
        # clean_joint_seq_latents: seq x bs x d #
        # clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone())
        ## pert_joint_seq_latents, pert_basejtsrel_seq_latents ##
        out_dict = model(input_data, self._scale_timesteps(my_t).clone())
        
        dec_in_dict = {}


        if self.diff_basejtse:
            ### Sample for perturbed basejtsrel seq latents ###
            basejtse_seq_latents = out_dict["base_jts_e_feats"] # 
            
            if 'base_jts_e_feats_mean' in out_dict:
                basejtse_seq_latents = out_dict['base_jts_e_feats_mean']
            
            noise_basejtse_seq_latents = th.randn_like(basejtse_seq_latents)
            pert_basejtse_seq_latents = self.q_sample(basejtse_seq_latents.permute(1, 0, 2).contiguous(), my_t, noise=noise_basejtse_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
            
            
            if self.args.rnd_noise:
                dec_in_dict['base_jts_e_feats'] = basejtse_seq_latents
            else:
                dec_in_dict['base_jts_e_feats'] = pert_basejtse_seq_latents
            
            dec_in_dict['base_jts_e_feats'] = pert_basejtse_seq_latents
            dec_in_dict['base_jts_e_feats_enc'] = basejtse_seq_latents
        
        # dec in dict here #
        # dec_in_dict = {
        #     "joints_seq_latents": pert_joint_seq_latents,
        #     "rel_base_pts_outputs": pert_basejtsrel_seq_latents,
        #     "base_jts_e_feats": pert_basejtse_seq_latents,
        # }
        ### !!! update for input data !!! ###
        dec_in_dict['input_data'] = input_data
        
        # input_data['pert_joint_seq_latents'] = pert_joint_seq_latents ## decoded 
        
        
        model_kwargs = {
            k: val for k, val in init_image.items() if k not in input_data
        }

        if progress:
            # Lazy import so that we don't depend on tqdm.
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i_idx, i in enumerate(indices):
            t = th.tensor([i] * shape[0], device=device)
            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
                                               size=model_kwargs['y'].shape, # size of y 
                                               device=model_kwargs['y'].device) # device of y 
            with th.no_grad(): # inter_optim #
                # p_sample_with_grad
                sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
                # 
                
                out = sample_fn(
                    model,
                    dec_in_dict, ## sample from input data ##
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs,
                    const_noise=const_noise,
                )
                # yield out
            # img = out["sample"]
            # dec_clean_joint_seq = out["dec_clean_joint_seq"]
            
            input_data = {}
            dec_in_dict = {}
            # dec_clean_joint_seq = dec_clean_joint_seq * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) # latent spapce ->  # single step noise 
            
            
            ## gaussian diffusion ours ##
            # basejtsrel_output: bsz x nf x nnj x nnb x 3 --> rel outputs #
            # diff_basejtse ## basejtse ##
            if self.diff_basejtse: ## seq latents ## seq latents ##
                basejtse_seq_latents_sample = out["basejtse_seq_latents_sample"]
                dec_e_along_normals = out["dec_e_along_normals"]
                dec_e_vt_normals = out["dec_e_vt_normals"]
                dec_d = out["dec_d"]
                rel_vel_dec = out["rel_vel_dec"]
            
                dec_e_along_normals = dec_e_along_normals * init_image['per_frame_std_disp_along_normals'] + init_image['per_frame_avg_disp_along_normals']
                dec_e_vt_normals = dec_e_vt_normals * init_image['per_frame_std_disp_vt_normals'] + init_image['per_frame_avg_disp_vt_normals']
                
                ## model constraints and model impacts from object a to object c ##
                basejtse_seq_input_dict = {
                    'e_disp_rel_to_base_along_normals': dec_e_along_normals,
                    'e_disp_rel_to_baes_vt_normals': dec_e_vt_normals, ### vt_normals ###
                    'rel_vel_dec': rel_vel_dec,
                    'dec_d': dec_d
                }
                # basejts e seq dec in dict #
                basejtse_seq_dec_in_dict = {
                    "base_jts_e_feats": basejtse_seq_latents_sample,
                }
            else:
                basejtse_seq_input_dict = {}
                basejtse_seq_dec_in_dict = {}
            

            input_data = {
                'base_pts': base_pts,
                'base_normals': base_normals,
                'rhand_joints': rhand_joints,
            }
            
            ## jts seq 
            input_data.update(jts_seq_input_dict)
            input_data.update(basejtsrel_seq_input_dict)
            input_data.update(basejtse_seq_input_dict)
            
            if 'sampled_base_pts_nearest_obj_pc' in init_image:
                input_data.update(ambient_init_image)
            input_data.update(init_image_avg_std_stats)
            input_data['rhand_joints']=  rhand_joints
            
            ## input_data ##
            ### sampled rhand joints ###
            if 'sampled_rhand_joints' not in input_data: # sampled_rhand_joints --> sampled rhand joints 
                sampled_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1) # latent spapce ->  # single step noise #
                input_data['sampled_rhand_joints'] = sampled_rhand_joints # sampled_rhand_joints #
            # sampled_rhand_joints --> sampled_rhand_joints #
            dec_in_dict.update(jts_seq_dec_in_dict)
            dec_in_dict.update(basejtse_seq_dec_in_dict)
            dec_in_dict.update(basejtsrel_seq_dec_in_dict)
            
            dec_in_dict['input_data'] = input_data
            
            # dec_in_dict = {
                
                
            #     "input_data": input_data,
            # }
            
            
            yield input_data
                        

    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Evaluate a single timestep's term of the variational lower bound.

        The true posterior q(x_{t-1} | x_t, x_0) is compared against the
        model's predicted p(x_{t-1} | x_t). Every value is divided by ln(2),
        so the result is expressed in bits rather than nats. This makes it
        comparable with numbers reported in other papers.

        :param model: the denoising model, forwarded to p_mean_variance().
        :param x_start: the clean data x_0.
        :param x_t: the noised data at timestep t.
        :param t: a batch of timestep indices.
        :param clip_denoised: forwarded to p_mean_variance().
        :param model_kwargs: extra keyword arguments forwarded to the model.
        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor of NLLs or KLs.
                 - 'pred_xstart': the x_0 predictions.
        """
        nats_per_bit = np.log(2.0)

        # Closed-form posterior; its variance (middle element) is not needed.
        posterior_mean, _, posterior_log_var = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        # Model prediction of the reverse-step Gaussian.
        prediction = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        pred_mean = prediction["mean"]
        pred_log_var = prediction["log_variance"]

        # KL(q || p), reduced per example with mean_flat and converted to bits.
        kl_per_elem = normal_kl(
            posterior_mean, posterior_log_var, pred_mean, pred_log_var
        )
        kl_bits = mean_flat(kl_per_elem) / nats_per_bit

        # Negative log-likelihood of x_0 under the discretized Gaussian
        # decoder. The decoder takes log standard deviations, i.e. half the
        # log variance.
        nll_per_elem = -discretized_gaussian_log_likelihood(
            x_start, means=pred_mean, log_scales=0.5 * pred_log_var
        )
        assert nll_per_elem.shape == x_start.shape
        nll_bits = mean_flat(nll_per_elem) / nats_per_bit

        # The decoder NLL applies only at timestep 0; every later timestep
        # uses the KL term.
        is_first_step = t == 0
        vb_term = th.where(is_first_step, nll_bits, kl_bits)

        return dict(output=vb_term, pred_xstart=prediction["pred_xstart"])


    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None):
        """ # training losses # training losses for rel/dist representations ## 
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        
        if self.args.train_diff: # set enc to evals #
            # print(f"Setitng encoders to eval mode")
            model.model.set_enc_to_eval()

        enc = model.model ## model.model
        mask = model_kwargs['y']['mask'] ## rot2xyz
        get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation,
                                             glob=enc.glob, ## rot2xyz; ##
                                             # jointstype='vertices',  # 3.4 iter/sec # USED ALSO IN MotionCLIP
                                             jointstype='smpl',  # 3.4 iter/sec
                                             vertstrans=False)
        # bsz x ws x nnj x 3 #
        # base_pts: bsz x nnb x 3 #
        # base_normals: bsz x nnb x 3 #
        # bsz x ws x nnjts x 3 #
        rhand_joints = x_start['rhand_joints']
        # bsz x nnbase x 3 #
        base_pts = x_start['base_pts']
        # bsz x ws x nnbase x 3 #
        base_normals = x_start['base_normals']
        vel_obj_pts_to_hand_pts = x_start["vel_obj_pts_to_hand_pts"]
        obj_pts_disp = x_start["obj_pts_disp"]
        dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints']
        avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ##
        # # bsz x ws x nnjts x nnbase x 3 #
        # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints']
        # # bsz x ws x nnjts x nnbase #
        # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints']
        ## no diffu
        # normalization strategy for joints and that for the representation values #
        if 'sampled_base_pts_nearest_obj_pc' in x_start:
            ambient_xstart_dict = {
                'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'],
                'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'],
            }

        # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals
        # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals
        if self.args.wo_e_normalization:
            x_start['per_frame_avg_disp_along_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_along_normals'])
            x_start['per_frame_avg_disp_vt_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_vt_normals'])
            x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals'])
            x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals'])
        if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel
            # x_start['per_frame_avg_joints_rel'] = torch.zeros_like(x_start['per_frame_avg_joints_rel'])
            
            x_start['per_frame_std_joints_rel'] = torch.ones_like(x_start['per_frame_std_joints_rel'])

        ## rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ##
        # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ##
        rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ##
        x_start['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True)
        # x_start['per_frame_avg_joints_rel'] = torch
        # bsz x ws x nnj x nnb x 3 #
        rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - x_start['per_frame_avg_joints_rel']) / x_start['per_frame_std_joints_rel']
        
        # data normalization;
        # construct statistics, normalize values #
        joints_scaling_factor = 5.
        ''' GET rel and dists ''' 
        if self.denoising_stra == "rep":
            # bsz x ws x nnj x nnb x 3 #
            # avg_jts: 1 x nnj x 3
            # std_jts: 1 x nnj x 3 # rhand joints 
            # rhand_joints: bsz x ws x nnj x 3; normalize rhand joitns #
            normed_rhand_joints = (rhand_joints - self.avg_jts.unsqueeze(0).to(rhand_joints.device)) / self.std_jts.unsqueeze(0).to(rhand_joints.device)
            noise_rhand_joints = th.randn_like(normed_rhand_joints)
            pert_normed_rhand_joints = self.q_sample(normed_rhand_joints, t, noise=noise_rhand_joints)
            
            # 
            ### scale rhand joints ##
            # rhand joints: bsz x ws x nnj x 3 ## each joint 1 x 3 -> normalization #
            
            exp_rhand_joints = rhand_joints.view(rhand_joints.size(0), rhand_joints.size(1) * rhand_joints.size(2), 3)
            
            if self.args.jts_sclae_stra == "std":
                avg_exp_rhand_joints = torch.mean(exp_rhand_joints, dim=1, keepdim=True)
                extents_rhand_joints = torch.std(exp_rhand_joints, dim=1, keepdim=True)
            elif self.args.jts_sclae_stra == "bbox":
                ### bounding box ###
                maxx_exp_rhand_joints, _ = torch.max(exp_rhand_joints, dim=1, keepdim=True)
                minn_exp_rhand_joints, _ = torch.min(exp_rhand_joints, dim=1, keepdim=True)
                avg_exp_rhand_joints = (maxx_exp_rhand_joints + minn_exp_rhand_joints) / 2. # avg_exp_rhand_joints
                # 
                extents_rhand_joints = maxx_exp_rhand_joints - minn_exp_rhand_joints ### bsz x 1 x 3 # 
                extents_rhand_joints = torch.sqrt(torch.sum(extents_rhand_joints ** 2, dim=-1, keepdim=True))
                ### bounding box ###
            else:
                raise ValueError(f"Unrecognized jts_scale_str: {self.args.jts_sclae_stra}")
            
            ## 
            rhand_joints = (rhand_joints - avg_exp_rhand_joints.unsqueeze(1) ) / extents_rhand_joints.unsqueeze(1)
            
            scaled_rhand_joints = rhand_joints * joints_scaling_factor
            noise_scaled_rhand_joints = th.randn_like(scaled_rhand_joints)
            pert_scaled_rhand_joints = self.q_sample(scaled_rhand_joints, t, noise=noise_scaled_rhand_joints)
            
            # pert_rhand_joints: bsz x nnj x 3 ## -> 
            pert_rhand_joints = pert_normed_rhand_joints * self.std_jts.unsqueeze(0).to(rhand_joints.device) + self.avg_jts.unsqueeze(0).to(rhand_joints.device)
            # and 
            
            
            ### Calculate moving related energies ###
            # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here #
            # pert rhand joints # ## denormed base pts to rhand joints -> denormed rel positions ##
            denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * x_start['per_frame_std_joints_rel'] + x_start['per_frame_avg_joints_rel']
            denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1)
            denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum(
                denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1
            ) ## l2 real base pts 
            k_f = 1. ## l2 rel base pts to pert rhand joints ##
            # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb #
            l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1)
            ### att_forces ##
            att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb #
            # bsz x (ws - 1) x nnj x nnb #
            att_forces = att_forces[:, :-1, :, :] # attraction forces -1 #
            # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ##
            # bsz x (ws - 1) x nnj x 3 --> displacements s#
            denormed_rhand_joints_disp = denormed_rhand_joints[:, 1:, :, :] - denormed_rhand_joints[:, :-1, :, :]
            # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # 
            # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb #
            signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum(
                base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1
            )
            # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals #
            rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2)  - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum(
                rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1
            ))
            k_a = 1.
            k_b = 1.
            ### 
            e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal)
            # (ws - 1) x nnj x nnb # -> dist vt normals # ## 
            e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal
            # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ##
            # 
            # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
            # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
            # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals",  x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size())
            # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals
            # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals
            e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals']
            e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals']

        else:  
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")
        ''' GET rel and dists '''
        
        
        input_data = {
          'base_pts': base_pts.clone(), # base pts ###
          'base_normals': base_normals.clone(), # base normals ### 
          'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(), 
        #   'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints.clone(),
        #   'pert_rhand_joints': pert_normed_rhand_joints,
          # scaled_rhand_joints, pert_scaled_rhand_joints
          'pert_rhand_joints': pert_scaled_rhand_joints, # 
          'rhand_joints': rhand_joints,
          'avg_joints_sequence': avg_joints_sequence, ## bsz x nnjoints x 3 here for the avg_joints ##
        }
        if 'sampled_base_pts_nearest_obj_pc' in x_start:
            input_data.update(ambient_xstart_dict)
        # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
        # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
        # bsz x ws - 1 x nnj x nnb # # input_data 
        disp_dist = dist_base_pts_to_rhand_joints[:-1] # (ws - 1 ) x nnj x nnb #
        input_data.update(
            {
                # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals
                ### e_disp_rel_to_base_along_normals: 
                'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, 
                'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals, ## clean values # # the denoising space is then transformed to the latent space; noisy inputs -> latent code --> can we really denoise them correctly
                "obj_pts_disp": obj_pts_disp,
                'vel_obj_pts_to_hand_pts': vel_obj_pts_to_hand_pts,
                'disp_dist': disp_dist
            }
        )
        # input_data.update(
        #   {k: x_start[k].clone() for k in x_start if k not in input_data}
        # ) # gaussian diffusion ours ## 
        # rel_base_pts_to_rhand_joints in the input_data #
        if model_kwargs is None:
            model_kwargs = {}

        terms = {} # latents in the latent space # # sequence latents #
        
        
        if self.args.train_diff:
            with torch.no_grad():
                out_dict = model(input_data, self._scale_timesteps(t).clone())
        else:
            # clean_joint_seq_latents: seq x bs x d #
            # clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone()) ## 
            ### the strategy of removing noise from corresponding quantities ###
            out_dict = model(input_data, self._scale_timesteps(t).clone())
            ### get model output dictionary ###
        
        KL_loss = 0.
        
        # out dict of the #
        # reumse checkpoints  #dec_in_dict
        dec_in_dict = {}
  
        if self.diff_basejtse:
            ### Sample for perturbed basejtsrel seq latents ###
            basejtse_seq_latents = out_dict["base_jts_e_feats"]
            noise_basejtse_seq_latents = th.randn_like(basejtse_seq_latents)
            pert_basejtse_seq_latents = self.q_sample(basejtse_seq_latents.permute(1, 0, 2).contiguous(), t, noise=noise_basejtse_seq_latents.permute(1, 0, 2).contiguous()).permute(1, 0, 2).contiguous()
            ### Sample for perturbed basejtsrel seq latents ###
            dec_in_dict['base_jts_e_feats'] = pert_basejtse_seq_latents
            dec_in_dict['base_jts_e_feats_enc'] = basejtse_seq_latents
            
            
            if self.args.kl_weights > 0. and "base_jts_e_feats_mean" in out_dict and not self.args.train_diff:
                # clean_joint_seq_latents: seq_len x bs x d #
                log_p_base_jts_e_seq = model_util.standard_normal_logprob(basejtse_seq_latents)
                log_p_base_jts_e_seq  = log_p_base_jts_e_seq.permute(1, 0, 2).contiguous() # 
                log_p_base_jts_e_seq = log_p_base_jts_e_seq.sum(dim=-1).mean(dim=-1)
                # log_p_joints_seq
                entropy_base_jts_e_seq = model_util.gaussian_entropy(out_dict['base_jts_e_feats_logvar'].permute(1, 2, 0)).mean(dim=-1)
                loss_prior_base_jts_e_seq =  (- log_p_base_jts_e_seq - entropy_base_jts_e_seq)
                KL_loss += loss_prior_base_jts_e_seq
        
        
        # dec_in_dict = {
        #     "joints_seq_latents": pert_joint_seq_latents,
        #     "rel_base_pts_outputs": pert_basejtsrel_seq_latents,
        #     "base_jts_e_feats": pert_basejtse_seq_latents,
        # }

        # # dec_clean_joint_seq: bsz x ws x nnj x 3
        # # dec_clena_seq_latents: seq x bs x d
        # dec_clean_joint_seq, dec_clena_seq_latents = model.model.dec_latents_to_joints_with_t(pert_joint_seq_latents, self._scale_timesteps(t).clone())
        
        # dec_clean_joint_seq: bsz x ws x nnj x 3
        # dec_clena_seq_latents: seq x bs x d
        dec_out_dict = model.model.dec_latents_to_joints_with_t(dec_in_dict, input_data, self._scale_timesteps(t).clone())
        
        
        terms['rot_mse'] = 0.

    
        if self.diff_basejtse:
            dec_base_jts_e_feats = dec_out_dict['base_jts_e_feats']
            dec_e_along_normals = dec_out_dict['dec_e_along_normals']
            dec_e_vt_normals = dec_out_dict['dec_e_vt_normals']
            dec_d = dec_out_dict['dec_d']
            rel_vel_dec = dec_out_dict['rel_vel_dec']
            # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals
            # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 #
            basejtse_along_normals_pred_loss = torch.sum(
                (e_disp_rel_to_base_along_normals.unsqueeze(-1) - dec_e_along_normals.unsqueeze(-1)) ** 2, dim=-1
            ).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            basejtse_vt_normals_pred_loss = torch.sum(
                (e_disp_rel_to_baes_vt_normals.unsqueeze(-1) - dec_e_vt_normals.unsqueeze(-1)) ** 2, dim=-1
            ).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            d_pred_loss = torch.sum(
                (dec_d.unsqueeze(-1) - disp_dist.unsqueeze(-1)) ** 2, dim=-1
            ).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            rel_vel_pred_loss = torch.sum(
                (rel_vel_dec.unsqueeze(-1) - vel_obj_pts_to_hand_pts.unsqueeze(-1)) ** 2, dim=-1
            ).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            
            # basejtse_latent_denoising_loss = (torch.sum(
            #     (basejtse_seq_latents.permute(1, 0, 2).contiguous() - dec_base_jts_e_feats.permute(1, 0, 2).contiguous()) ** 2, dim=-1
            # ) / basejtse_seq_latents.size(-1)).mean(dim=-1)
            
            
            if self.args.pred_diff_noise:
                # noise_joint_seq_latents 
                basejtse_latent_denoising_loss = (torch.sum(
                    (basejtse_seq_latents.permute(1, 0, 2).contiguous() - noise_basejtse_seq_latents.permute(1, 0, 2).contiguous()) ** 2, dim=-1
                ) / basejtse_seq_latents.size(-1)).mean(dim=-1)
            else:
                basejtse_latent_denoising_loss = (torch.sum(
                    (basejtse_seq_latents.permute(1, 0, 2).contiguous() - dec_base_jts_e_feats.permute(1, 0, 2).contiguous()) ** 2, dim=-1
                ) / basejtse_seq_latents.size(-1)).mean(dim=-1)
            
            # find out kong #
            
            
            if self.args.train_enc:
                basejtse_latent_denoising_loss = torch.zeros_like(basejtse_latent_denoising_loss)
            
            if self.args.train_diff: # train_diff # no basejtse denoising losses ##
                # basejtse_latent_denoising_loss = torch.zeros_like(basejtse_latent_denoising_loss)
                basejtse_along_normals_pred_loss = torch.zeros_like(basejtse_along_normals_pred_loss)
                basejtse_vt_normals_pred_loss = torch.zeros_like(basejtse_vt_normals_pred_loss)
                d_pred_loss = torch.zeros_like(d_pred_loss)
                rel_vel_pred_loss = torch.zeros_like(rel_vel_pred_loss)
            
            terms['basejtse_along_normals_pred_loss'] = basejtse_along_normals_pred_loss
            terms['basejtse_vt_normals_pred_loss'] = basejtse_vt_normals_pred_loss
            terms['basejtse_latent_denoising_loss'] = basejtse_latent_denoising_loss
            terms['d_pred_loss'] = d_pred_loss
            terms['rel_vel_pred_loss'] = rel_vel_pred_loss
            # terms['rot_mse'] += basejtse_along_normals_pred_loss * 20 + basejtse_vt_normals_pred_loss * 20 + basejtse_latent_denoising_loss
            
            terms['rot_mse'] += basejtse_along_normals_pred_loss * self.args.basejtse_along_normal_loss_coeff + basejtse_vt_normals_pred_loss * self.args.basejtse_vt_normal_loss_coeff + basejtse_latent_denoising_loss + d_pred_loss * self.basejtse_along_normal_loss_coeff + rel_vel_pred_loss * self.basejtse_along_normal_loss_coeff
        
        if self.args.kl_weights > 0.  and not self.args.train_diff:
            terms['KL_loss'] = KL_loss
            terms['rot_mse'] += KL_loss * self.args.kl_weights
            
        
        import os
        import datetime
        cur_time_stamp = datetime.datetime.now().timestamp()
        cur_time_stamp = str(cur_time_stamp)
        target_xyz, model_output_xyz = None, None


        terms["loss"] = terms["rot_mse"] 
        # else:
        #     raise NotImplementedError(self.loss_type)

        return terms


    ## training losses 
    def predict_sample_single_step(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None):
        """
        Noise ``x_start`` to timestep ``t``, run the model once and compute the loss terms.

        ``x_start`` is diffused to ``x_t`` with ``self.q_sample``, the model is evaluated on
        ``x_t`` and the loss selected by ``self.loss_type`` is computed. For the MSE loss
        types the (optionally de-normalized) prediction and target are additionally written,
        together with ``t``, to ``./out_<timestamp>.npy``.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
            ``model_kwargs['y']['mask']`` masks the MSE terms; when it is absent every
            frame is treated as valid. If ``model_kwargs['y']`` contains
            ``avg_joints``/``std_joints`` the prediction and target are de-normalized with
            them after the main MSE term is computed.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :param dataset: only read (``dataset.dataname``) by the rot6d geometric-loss branches.
        :return: a tuple ``(terms, model_output, target, t)``. ``terms`` holds the key
                 "loss" plus the individual loss terms; ``model_output`` and ``target``
                 are ``None`` for the KL loss types (no point prediction is produced).
        """
        # Resolve the conditioning dict *before* reading from it, so the documented
        # default model_kwargs=None is actually usable.
        if model_kwargs is None:
            model_kwargs = {}
        y_cond = model_kwargs.get('y', {})

        mask = y_cond.get('mask')
        if mask is None:
            # No mask supplied: every frame counts. Shape follows the [bs, 1, 1, seqlen]
            # layout masked_l2 documents (x_start assumed [bs, njoints, nfeats, nframes]).
            mask = th.ones((x_start.shape[0], 1, 1, x_start.shape[-1]), dtype=th.bool, device=x_start.device)

        # ### avg_joints, std_joints ### #
        # NOTE(review): avg is unsqueezed once and std twice; kept as-is, confirm the
        # intended broadcast against the stored statistics' shapes.
        if 'avg_joints' in y_cond:
            avg_joints = y_cond['avg_joints'].unsqueeze(-1)
            std_joints = y_cond['std_joints'].unsqueeze(-1).unsqueeze(-1)
        else:
            avg_joints = None
            std_joints = None

        if noise is None:
            noise = th.randn_like(x_start)
        # Forward-diffuse the clean input to the requested timestep(s).
        x_t = self.q_sample(x_start, t, noise=noise)

        terms = {}
        # Bound up front so the final return is valid for every loss type.
        model_output, target = None, None

        if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
            terms["loss"] = self._vb_terms_bpd(
                model=model,
                x_start=x_start,
                x_t=x_t,
                t=t,
                clip_denoised=False,
                model_kwargs=model_kwargs,
            )["output"]
            if self.loss_type == LossType.RESCALED_KL:
                terms["loss"] *= self.num_timesteps
        elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
            model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
            if self.model_var_type in [
                ModelVarType.LEARNED,
                ModelVarType.LEARNED_RANGE,
            ]:
                B, C = x_t.shape[:2]
                assert model_output.shape == (B, C * 2, *x_t.shape[2:])
                model_output, model_var_values = th.split(model_output, C, dim=1)
                # Learn the variance using the variational bound, but don't let
                # it affect our mean prediction.
                frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
                terms["vb"] = self._vb_terms_bpd(
                    model=lambda *args, r=frozen_out: r,
                    x_start=x_start,
                    x_t=x_t,
                    t=t,
                    clip_denoised=False,
                )["output"]
                if self.loss_type == LossType.RESCALED_MSE:
                    # Divide by 1000 for equivalence with initial implementation.
                    # Without a factor of 1/1000, the VB term hurts the MSE term.
                    terms["vb"] *= self.num_timesteps / 1000.0

            # The regression target depends on what the model is configured to predict.
            target = {
                ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
                    x_start=x_start, x_t=x_t, t=t
                )[0],
                ModelMeanType.START_X: x_start,
                ModelMeanType.EPSILON: noise,
            }[self.model_mean_type]
            assert model_output.shape == target.shape == x_start.shape  # [bs, njoints, nfeats, nframes]

            terms["rot_mse"] = self.masked_l2(target, model_output, mask)

            # ### avg_joints, std_joints ### #
            # De-normalization happens after rot_mse, so only the saved arrays and the
            # velocity/geometric terms below see de-normalized values.
            if avg_joints is not None:
                print(f"model_output: {model_output.size()}, target: {target.size()}, std_joints: {std_joints.size()}, avg_joints: {avg_joints.size()}")
                print(f"Denormalizing joints...")
                model_output = (model_output * std_joints) + avg_joints
                target = (target * std_joints) + avg_joints

            # Side effect: dump prediction/target/t to a timestamped file in the CWD.
            import os
            import datetime
            sv_out_in = {
                'model_output': model_output.detach().cpu().numpy(),
                'target': target.detach().cpu().numpy(),
                't': t.detach().cpu().numpy(),
            }
            cur_time_stamp = str(datetime.datetime.now().timestamp())
            sv_dir_rt = "."
            sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy")
            np.save(sv_out_fn, sv_out_in)
            print(f"Samples saved to {sv_out_fn}")

            target_xyz, model_output_xyz = None, None
            # NOTE(review): the xyz reconstruction loss is hard-disabled here (and this
            # persists on the instance) because no `get_xyz` helper exists in this
            # method; its former `if self.lambda_rcxyz > 0.` branch was unreachable.
            self.lambda_rcxyz = 0.

            # NOTE(review): the two rot6d branches below call `get_xyz`, which is not
            # defined in this method, and read `dataset.dataname`; they only work if
            # neither lambda is positive or data_rep != 'rot6d'.
            if self.lambda_vel_rcxyz > 0.:
                if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
                    target_xyz = get_xyz(target) if target_xyz is None else target_xyz
                    model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
                    target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1])
                    model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1])
                    terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:])

            if self.lambda_fc > 0.:
                # Global side effect: enables autograd anomaly detection process-wide.
                torch.autograd.set_detect_anomaly(True)
                if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
                    target_xyz = get_xyz(target) if target_xyz is None else target_xyz
                    model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
                    # 'L_Ankle',  # 7, 'R_Ankle',  # 8 , 'L_Foot',  # 10, 'R_Foot',  # 11
                    l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11
                    relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx]
                    gt_joint_xyz = target_xyz[:, relevant_joints, :, :]  # [BatchSize, 4, 3, Frames]
                    gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2)  # [BatchSize, 4, Frames]
                    fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1)
                    pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :]  # [BatchSize, 4, 3, Frames]
                    pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1]
                    pred_vel[~fc_mask] = 0
                    terms["fc"] = self.masked_l2(pred_vel,
                                                 torch.zeros(pred_vel.shape, device=pred_vel.device),
                                                 mask[:, :, :, 1:])
            if self.lambda_vel > 0.:
                target_vel = (target[..., 1:] - target[..., :-1])
                model_output_vel = (model_output[..., 1:] - model_output[..., :-1])
                terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :],  # Remove last joint, is the root location!
                                                  model_output_vel[:, :-1, :, :],
                                                  mask[:, :, :, 1:])  # mean_flat((target_vel - model_output_vel) ** 2)

            terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\
                            (self.lambda_vel * terms.get('vel_mse', 0.)) +\
                            (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \
                            (self.lambda_fc * terms.get('fc', 0.))

        else:
            raise NotImplementedError(self.loss_type)

        return terms, model_output, target, t



### for motiondiff and spatialdiff ##
class GaussianDiffusionV5:
    """
    Utilities for training and sampling diffusion models.

    Ported directly from here, and then adapted over time to further experimentation.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42

    :param betas: a 1-D numpy array of betas for each diffusion timestep, ## pass floating point timesteps ##
                  starting at T and going to 1.
    :param model_mean_type: a ModelMeanType determining what the model outputs.
    :param model_var_type: a ModelVarType determining how variance is output.
    :param loss_type: a LossType determining the loss function to use.
    :param rescale_timesteps: if True, pass floating point timesteps into the
                              model so that they are always scaled like in the
                              original paper (0 to 1000).
    """

    def __init__(
        self,
        *,
        betas,
        model_mean_type,
        model_var_type,
        loss_type,
        rescale_timesteps=False,
        lambda_rcxyz=0.,
        lambda_vel=0.,
        lambda_pose=1.,
        lambda_orient=1.,
        lambda_loc=1.,
        data_rep='rot6d',
        lambda_root_vel=0.,
        lambda_vel_rcxyz=0.,
        lambda_fc=0.,
        denoising_stra="rep",
        inter_optim=False,
        args=None,
    ):
        """
        Store the configuration and precompute every per-timestep schedule array.

        ``args`` defaults to None but is dereferenced immediately (``diff_jts``,
        ``diff_basejtsrel``, ``diff_basejtse``, ``diff_realbasejtsrel``), so callers
        must in practice supply it.

        :raises ValueError: if ``lambda_pose != 1`` while ``data_rep != 'rot_vel'``.
        """
        self.model_mean_type = model_mean_type
        self.model_var_type = model_var_type
        self.loss_type = loss_type
        self.rescale_timesteps = rescale_timesteps
        self.data_rep = data_rep
        self.args = args

        # Which quantities take part in diffusion; copied verbatim from args.
        for flag_name in ("diff_jts", "diff_basejtsrel", "diff_basejtse", "diff_realbasejtsrel"):
            setattr(self, flag_name, getattr(self.args, flag_name))

        if lambda_pose != 1. and data_rep != 'rot_vel':
            raise ValueError('lambda_pose is relevant only when training on velocities!')

        self.lambda_pose, self.lambda_orient, self.lambda_loc = lambda_pose, lambda_orient, lambda_loc
        self.lambda_rcxyz = lambda_rcxyz
        self.lambda_vel = lambda_vel
        self.lambda_root_vel = lambda_root_vel
        self.lambda_vel_rcxyz = lambda_vel_rcxyz
        self.lambda_fc = lambda_fc

        # Options consumed by the denoising process.
        self.denoising_stra = denoising_stra
        self.inter_optim = inter_optim

        geometric_weights = (
            self.lambda_rcxyz,
            self.lambda_vel,
            self.lambda_root_vel,
            self.lambda_vel_rcxyz,
            self.lambda_fc,
        )
        if any(weight > 0. for weight in geometric_weights):
            assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!'

        # Torch-side schedule, built from the betas exactly as they were passed in.
        self.var_sched = VarianceSchedule(len(betas), torch.tensor(betas, dtype=torch.float64))

        # NumPy-side schedule, kept in float64 for accuracy.
        beta_arr = np.array(betas, dtype=np.float64)
        self.betas = beta_arr
        assert len(beta_arr.shape) == 1, "betas must be 1-D"
        assert (beta_arr > 0).all() and (beta_arr <= 1).all()
        self.num_timesteps = int(beta_arr.shape[0])

        alpha_arr = 1.0 - beta_arr
        alpha_bar = np.cumprod(alpha_arr, axis=0)
        one_minus_alpha_bar = 1.0 - alpha_bar
        self.alphas_cumprod = alpha_bar
        self.alphas_cumprod_prev = np.append(1.0, alpha_bar[:-1])
        self.alphas_cumprod_next = np.append(alpha_bar[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)

        # Forward-process quantities for q(x_t | x_0).
        self.sqrt_alphas_cumprod = np.sqrt(alpha_bar)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(one_minus_alpha_bar)
        self.log_one_minus_alphas_cumprod = np.log(one_minus_alpha_bar)
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alpha_bar)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alpha_bar - 1)

        # Posterior q(x_{t-1} | x_t, x_0).
        alpha_bar_prev = self.alphas_cumprod_prev
        self.posterior_variance = beta_arr * (1.0 - alpha_bar_prev) / one_minus_alpha_bar
        # The posterior variance is 0 at t=0, so its log is clipped by reusing entry 1.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1], self.posterior_variance[1:])
        )
        self.posterior_mean_coef1 = beta_arr * np.sqrt(alpha_bar_prev) / one_minus_alpha_bar
        self.posterior_mean_coef2 = (1.0 - alpha_bar_prev) * np.sqrt(alpha_arr) / one_minus_alpha_bar

        # Element-wise squared error; left unreduced so masked_l2 can apply its mask.
        self.l2_loss = lambda a, b: (a - b) ** 2
        

    def masked_l2(self, a, b, mask):
        """
        Masked mean squared error, one value per batch element.

        Returns ``sum_flat(l2_loss(a, b) * mask) / (sum_flat(mask) * a.shape[1] * a.shape[2])``.
        Per the original layout notes, ``a``/``b`` are assumed to be [bs, J, Jdim, seqlen]
        and ``mask`` [bs, 1, 1, seqlen]; the normalization is only meaningful for that layout.

        NOTE(review): a batch element whose mask is entirely zero divides by zero.
        """
        squared_err = self.l2_loss(a, b)
        masked_total = sum_flat(squared_err * mask.float())
        per_frame_entries = a.shape[1] * a.shape[2]
        valid_count = sum_flat(mask) * per_frame_entries
        return masked_total / valid_count


    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).

        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape, with
                 mean = sqrt(alpha_bar_t) * x_start and variance = 1 - alpha_bar_t.
        """
        target_shape = x_start.shape
        scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, target_shape)
        variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, target_shape)
        log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, target_shape)
        return scale * x_start, variance, log_variance

    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the dataset for a given number of diffusion steps, i.e. sample q(x_t | x_0).

        Computes ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.

        :param x_start: the initial dataset batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise; otherwise fresh
                      standard-normal noise shaped like x_start is drawn.
        :return: A noisy version of x_start.
        """
        eps = th.randn_like(x_start) if noise is None else noise
        assert eps.shape == x_start.shape
        signal_scale = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * eps

    ## q_sample, 
    def q_posterior_mean_variance(self, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0).

        mean = coef1_t * x_start + coef2_t * x_t, using the precomputed
        ``posterior_mean_coef1`` / ``posterior_mean_coef2`` tables.

        :return: (posterior_mean, posterior_variance, posterior_log_variance_clipped).
        """
        assert x_start.shape == x_t.shape
        shape = x_t.shape
        coef_start = _extract_into_tensor(self.posterior_mean_coef1, t, shape)
        coef_noisy = _extract_into_tensor(self.posterior_mean_coef2, t, shape)
        posterior_mean = coef_start * x_start + coef_noisy * x_t
        posterior_variance = _extract_into_tensor(self.posterior_variance, t, shape)
        posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, shape)
        batch_sizes = (
            posterior_mean.shape[0],
            posterior_variance.shape[0],
            posterior_log_variance_clipped.shape[0],
        )
        assert batch_sizes[0] == batch_sizes[1] == batch_sizes[2] == x_start.shape[0]
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    # phy project pred joints; phy predict joints here # 
    def phy_projct_pred_joints(self, pred_joints, base_pts, base_normals):
        """
        Offset from the nearest base point for every joint on the back side of its normal.

        For each predicted joint, the nearest base point (squared Euclidean distance)
        is found. Let v = joint - nearest_base_pt. If dot(v, nearest_base_normal) < 0,
        v is returned for that joint; otherwise a zero vector is returned.

        :param pred_joints: bsz x nf x nn_jts x 3.
        :param base_pts: bsz x nn_base_pts x 3. When ``self.args.use_arti_obj`` is set it
            is used as-is, presumably already per-frame (bsz x nf x nn_base_pts x 3);
            TODO confirm.
        :param base_normals: same layout as ``base_pts``.
        :return: bsz x nf x nn_jts x 3 tensor.
        """
        nf = pred_joints.size(1)
        if self.args.use_arti_obj:
            # Neither tensor is modified below, so no copy is needed.
            base_pts_exp = base_pts
            base_normals_exp = base_normals
        else:
            # Replicate the static geometry over every frame.
            base_pts_exp = base_pts.unsqueeze(1).repeat(1, nf, 1, 1).contiguous()
            base_normals_exp = base_normals.unsqueeze(1).repeat(1, nf, 1, 1).contiguous()

        # Squared distance from every joint to every base point: bsz x nf x nn_jts x nn_base_pts.
        sq_dists = torch.sum(
            (pred_joints.unsqueeze(-2) - base_pts_exp.unsqueeze(2)) ** 2, dim=-1
        )
        # Only the indices are needed; the distances themselves were never used.
        nearest_idxes = sq_dists.argmin(dim=-1)  # bsz x nf x nn_jts
        nearest_base_pts = model_util.batched_index_select_ours(base_pts_exp, nearest_idxes, dim=2)
        nearest_base_normals = model_util.batched_index_select_ours(base_normals_exp, nearest_idxes, dim=2)

        base_to_jts = pred_joints - nearest_base_pts  # vector from base point to joint
        dot_with_normals = torch.sum(base_to_jts * nearest_base_normals, dim=-1)  # bsz x nf x nn_jts
        behind_normal = (dot_with_normals < 0.).unsqueeze(-1)
        return torch.where(behind_normal, base_to_jts, torch.zeros_like(base_to_jts))

    # # the full physical world here? ## 
    def p_mean_variance_cond( ## get mean data ##
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: the model, which takes a signal and a batch of timesteps
                      as input.
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D Tensor of timesteps.
        :param clip_denoised: if True, clip the denoised signal into [-1, 1]. # # clip the denoised signal ##
        :param denoised_fn: if not None, a function which applies to the
            x_start prediction before it is used to sample. Applies before
            clip_denoised. # denoised fn
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :return: a dict with the following keys:
                 - 'mean': the model mean output.
                 - 'variance': the model variance output.
                 - 'log_variance': the log of 'variance'.
                 - 'pred_xstart': the prediction for x_0.
        """ # p_mean_varaince #
        if model_kwargs is None:
            model_kwargs = {}
            
        B = x['base_pts'].shape[0]
        assert t.shape == (B,)
        
        # print(f"t_shape: {t.shape}", "base_pts", x['base_pts'].size())
        
        input_data = x
        
        out_dict = model(input_data, self._scale_timesteps(t).clone()) 
        
        rt_dict = {}
        real_basejtsrel_seq_rt_dict = {}
        basejtsrel_seq_rt_dict = {}
        realbasejtsrel_to_joints_rt_dict = {}

        model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped

        if self.diff_realbasejtsrel and self.diff_basejtsrel: # diff_realbasejtsrel, real_basejtsrel_seq_rt_dict
            
            pert_rel_base_pts_outputs = x['pert_joints_offset_sequence']
            basejtsrel_output = out_dict['joints_offset_output']
            score_jts = basejtsrel_output
            
            real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3
            pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints']
            
            
            # combine those two things #
            if self.args.pred_diff_noise and not self.args.train_enc:
                if self.args.add_noise_onjts: # add noise onjts #
                    if self.args.real_basejtsrel_norm_stra == "std": ### std
                        denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints']
                    else:
                        denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints
                    # jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    # jts_fr_basepts = jts_fr_basepts.mean(dim=-2)
                    jts_fr_basepts = pert_rel_base_pts_outputs # pert
                    
                    # score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3
                    
                    score_jts_fr_basepts = real_dec_basejtsrel[..., self.args.sel_basepts_idx, :]
                    
                    t_item = t[0].item()
                    alpha_bar = self.var_sched.alpha_bars[t_item]
                    
                    # combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts
                    # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts
                    # combined_socre = score_jts_fr_basepts
                    
                    if self.args.only_cmb_finger:
                        combined_socre = score_jts
                        # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts)[..., -5:, :]
                        # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + score_jts_fr_basepts[..., -5:, :] * 0.3
                        # combined_socre = combined_socre * 0.2 + score_jts_fr_basepts * 0.8
                        combined_socre = combined_socre * 0.1 + score_jts_fr_basepts * 0.9
                        # combined_socre = combined_socre * 0.05 + score_jts_fr_basepts * 0.95
                        # combined_socre = combined_socre * 0.5 + score_jts_fr_basepts * 0.5
                        # combined_socre = combined_socre 
                        # combined_socre = score_jts_fr_basepts 
                    else:
                        combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_basepts
                        combined_socre = score_jts # not cmb finger #
                        # combined_socre = score_jts_fr_basepts 
                    
                    if self.args.use_var_sched:
                        bsz = real_dec_basejtsrel.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #
                        
                        # sigma = sigma / 2. # use_var_sched -> #
                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        
                        beta = self.var_sched.betas[[t[0].item()] * bsz]
                        
                        z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts)
                        dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta
                    else: # x_{t-1}
                        dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre)
                    if not self.args.use_arti_obj:
                        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2)
                    # print(f"dec_jts_fr_basepts: {dec_jts_fr_basepts.size()}, normed_base_pts: ", x['normed_base_pts'].size(), "real_dec_basejtsrel:", real_dec_basejtsrel.size())
                    if self.args.real_basejtsrel_norm_stra == "std": 
                        real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints']
                        
                    # print(f"real_dec_basejtsrel: {real_dec_basejtsrel.size()}, denormed_rel_base_pts_to_rhand_joints: {denormed_rel_base_pts_to_rhand_joints.size()}, jts_fr_basepts: {jts_fr_basepts.size()}")

                elif self.args.add_noise_onjts_single: # add noise on single joint
                    if not self.args.use_arti_obj:
                        jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2)
                    jts_fr_basepts = jts_fr_basepts.mean(dim=-2)
                    dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel #
                    # dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :]
                    t_item = t[0].item()
                    alpha_bar = self.var_sched.alpha_bars[t_item]
                    socre_jts_fr_basepts = dec_jts_fr_basepts
                    
                    if self.args.only_cmb_finger:
                        combined_socre = score_jts
                        # strategy 1 --> conditioning #
                        # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] - (torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts)[..., -5:, :]
                        # strategy 2 --> linear interpolation #
                        # combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.7 + socre_jts_fr_basepts[..., -5:, :] * 0.3
                        combined_socre[..., -5:, :] = combined_socre[..., -5:, :] * 0.5 + socre_jts_fr_basepts[..., -5:, :] * 0.5
                    else:
                        combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts
                        combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts
                    
                    # combined_socre = socre_jts_fr_basepts
                    if self.args.use_var_sched:
                        bsz = dec_jts_fr_basepts.size(0)
                        
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #

                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        
                        beta = self.var_sched.betas[[t[0].item()] * bsz]
                        z = torch.randn_like(combined_socre) if t_item > 0 else torch.zeros_like(jts_fr_basepts)
                        dec_jts_fr_basepts = c0 * (pert_rel_base_pts_outputs - c1 * combined_socre) + sigma * z # theta # bsz x seq x nnjts x nnbase x 3 #
                    else:
                        dec_jts_fr_basepts = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=combined_socre)
                    if not self.args.use_arti_obj:
                        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2)
                else:
                    # raise ValueError(f"Add noise directly --- not implemented yet")
                    # # input_data
                    # self.args.real_basejtsrel_norm_stra == "std":
                        # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints
                    if self.args.real_basejtsrel_norm_stra == "std": ### std
                        denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints * x['std_rel_base_pts_to_rhand_joints'] + x['avg_rel_base_pts_to_rhand_joints']
                    else:
                        denormed_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_rhand_joints
                    if not self.args.use_arti_obj:
                        jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        jts_fr_basepts = denormed_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2)
                    jts_fr_basepts = jts_fr_basepts.mean(dim=-2)
                    score_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # bsz x seq x nnjts x 3
                    t_item = t[0].item()
                    alpha_bar = self.var_sched.alpha_bars[t_item]
                    
                    combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts
                    combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts
                    combined_socre = score_jts_fr_basepts
                    
                    if self.args.use_var_sched:
                        bsz = real_dec_basejtsrel.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #
                        
                        # sigma = sigma / 2.

                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        
                        beta = self.var_sched.betas[[t[0].item()] * bsz]
                        
                        z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts)
                        dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * combined_socre) + sigma * z # theta
                    else: # x_{t-1}
                        dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=combined_socre)
                    if not self.args.use_arti_obj:
                        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2)
                    if self.args.real_basejtsrel_norm_stra == "std": 
                        real_dec_basejtsrel = (real_dec_basejtsrel - x['avg_rel_base_pts_to_rhand_joints']) / x['std_rel_base_pts_to_rhand_joints']

            # diff_realbasejtsrel
            real_basejtsrel_seq_rt_dict = {
                "real_dec_basejtsrel": real_dec_basejtsrel,
            }
            
            basejtsrel_seq_rt_dict = {
                "basejtsrel_output": dec_jts_fr_basepts
            }
            
            
        if self.diff_basejtsrel and self.args.diff_realbasejtsrel_to_joints:
            pert_rel_base_pts_outputs = x['pert_joints_offset_sequence']
            basejtsrel_output = out_dict['joints_offset_output']
            score_jts = basejtsrel_output
            
            pert_joints_offset_output = x['pert_joints_offset_sequence'] # rel base pts outputs #
            dec_joints_offset_output = out_dict['joints_offset_output_from_rel'] # joints offset output #
            score_jts_fr_rel = dec_joints_offset_output # # pert joints offset sequence #
            
            # if self.args.pred_diff_noise: ## eps -> estimated-noises # predict noise? # 
            # # if self.args.use_var_sched:
            bsz = dec_joints_offset_output.size(0) ## batch size fo the batch, pert-joints, predicted-joints ##
            t_item = t[0].item()
            
            alpha_bar = self.var_sched.alpha_bars[t_item]
            
            combined_socre = score_jts - torch.sqrt(1. - alpha_bar) * score_jts_fr_rel
            # combined_socre = score_jts # - torch.sqrt(1. - alpha_bar) * socre_jts_fr_basepts
            # combined_socre = score_jts_fr_rel
            
            alpha = self.var_sched.alphas[t_item]
            alpha_bar = self.var_sched.alpha_bars[t_item]
            sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma # # sigma #

            c0 = 1.0 / torch.sqrt(alpha) ### c0? c1? --> 
            c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
            
            beta = self.var_sched.betas[[t[0].item()] * bsz]
            z = torch.randn_like(dec_joints_offset_output) if t_item > 0 else torch.zeros_like(dec_joints_offset_output) # dec_joints 
            dec_joints_offset_output = c0 * (pert_joints_offset_output - c1 * score_jts) + sigma * z # theta

            ### realjtsrel_to_joints and joints only ##
            realbasejtsrel_to_joints_rt_dict = {
                'dec_joints_offset_output': dec_joints_offset_output,
            }
            
            basejtsrel_seq_rt_dict = {
                "basejtsrel_output": dec_joints_offset_output
            }
        # else:
        #     realbasejtsrel_to_joints_rt_dict = {}
        #     basejtsrel_seq_rt_dict = {}
            
        # rt_dict.update(jts_seq_rt_dict)
        rt_dict.update(basejtsrel_seq_rt_dict)
        # rt_dict.update(basejtse_seq_rt_dict)
        rt_dict.update(real_basejtsrel_seq_rt_dict)
        rt_dict.update(realbasejtsrel_to_joints_rt_dict)
        
        return rt_dict

    def p_mean_variance(  # one denoiser pass -> per-branch reverse-step quantities
        self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
    ):
        """
        Run the denoiser once and turn its raw outputs into the quantities used
        for a reverse-diffusion step, separately for each enabled representation
        branch (``self.diff_realbasejtsrel`` and ``self.diff_basejtsrel``).

        Depending on ``self.args.pred_diff_noise`` and ``self.args.use_var_sched``,
        a model output is used either directly, as predicted noise converted to
        an x_0 estimate via ``_predict_xstart_from_eps``, or as predicted noise
        used to take one ancestral step with ``self.var_sched``. In the last case
        the stored value is a sample of x_{t-1}, not an x_0 estimate.

        :param model: callable invoked as ``model(x, scaled_t)`` that returns a
                      dict of outputs (e.g. 'real_dec_basejtsrel',
                      'basejtsrel_output', 'joints_offset_output'). In the
                      ``train_enc`` path it must also expose
                      ``model.model.denoising_realbasejtsrel_objbasefeats`` and
                      ``model.model.decode_realbasejtsrel_from_objbasefeats``.
        :param x: dict of model inputs (not a single tensor). It must contain
                  'base_pts' plus the perturbed / normalization entries read by
                  the enabled branches (e.g. 'pert_rel_base_pts_to_rhand_joints',
                  'normed_base_pts', 'base_normals',
                  'pert_joints_offset_sequence').
        :param t: a 1-D Tensor of timesteps with shape (B,), where
                  B = x['base_pts'].shape[0]. Several paths read only ``t[0]``.
                  NOTE(review): they assume every batch element shares the same
                  timestep.
        :param clip_denoised: accepted for API compatibility; not used here.
        :param denoised_fn: accepted for API compatibility; not used here.
        :param model_kwargs: defaulted to {} but not forwarded to the model.
        :return: a dict that may contain:
                 - 'real_dec_basejtsrel', 'real_dec_basejtsrel_mean',
                   'real_dec_basejtsrel_variance' and
                   'real_dec_basejtsrel_log_variance' when
                   ``self.diff_realbasejtsrel`` is set;
                 - 'dec_obj_base_pts_feats' when ``self.args.train_enc`` is
                   also set;
                 - 'basejtsrel_output' when ``self.diff_basejtsrel`` is set.
        """
        if model_kwargs is None:
            model_kwargs = {}
        
        # Version 1: x_start is predicted in the relative (joint - base point) domain.
        # x is the dict of model inputs.
        # TODO: process x_start from the predicted value (relative positions,
        # signed distances) with base_pts and base_normals.
        B = x['base_pts'].shape[0]
        assert t.shape == (B,)
        
        input_data = x
        
        # Single forward pass of the denoiser; out_dict holds the raw outputs of every head.
        out_dict = model(input_data, self._scale_timesteps(t).clone())
        
        rt_dict = {}
        # Variances always come from the fixed posterior buffers; no variance
        # is read from the model outputs.
        model_variance, model_log_variance = self.posterior_variance, self.posterior_log_variance_clipped # pmean variance
           
            
        if self.diff_realbasejtsrel: # branch producing real_basejtsrel_seq_rt_dict
            real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # presumably bsz x nf x nnj x nnb x 3
            pert_rel_base_pts_to_rhand_joints = x['pert_rel_base_pts_to_rhand_joints']
            
            # The model output is predicted noise: turn it into x_0 or an x_{t-1} sample.
            if self.args.pred_diff_noise and not self.args.train_enc:
                if self.args.add_noise_onjts:
                    # Noise lives on the joints: rebuild the noisy joints x_t by
                    # adding the base points back to the relative offsets.
                    if not self.args.use_arti_obj:
                        jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2)
                    # jts_fr_basepts = pert_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :] + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    dec_jts_fr_basepts = real_dec_basejtsrel # predicted noise (eps) for every base point
                    
                    # Optionally collapse the per-base-point noise into a single
                    # estimate: one selected base point, or the mean over base points.
                    # Both keep a size-1 axis at dim -2.
                    # NOTE(review): with that size-1 axis, the non-var-sched path
                    # below fails the shape assertion in _predict_xstart_from_eps
                    # unless there is only one base point.
                    if self.args.use_same_noise_for_rep:
                        if self.args.sel_basepts_idx >= 0:
                            dec_jts_fr_basepts = real_dec_basejtsrel[:, :, :, self.args.sel_basepts_idx: self.args.sel_basepts_idx + 1]
                        else:
                            dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2, keepdim=True)
                            # [:, :, :, self.args.sel_basepts_idx: self.args.sel_basepts_idx + 1]
                    # Earlier experiment (disabled): project the x_start estimate
                    # and convert it back to noise before stepping.
                    # if self.args.phy_guided_sampling and t[0].item() < 1:
                    #     # phy_guided_sampling #
                    #     pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 
                    #     if self.args.sel_basepts_idx >= 0:
                    #         pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 #
                    #     else:
                    #         pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 #
                    #     # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals)
                    #     joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 
                    #     x_start_projed = pred_dec_jts - joints_proj_dir
                    #     x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ##
                    #     dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 #
                        
                    #     dec_ratio = 0.95
                    #     dec_jts_fr_basepts = dec_jts_fr_basepts * dec_ratio + dec_jts_fr_basepts_projed * (1. - dec_ratio)
                        
                    
                    if self.args.use_var_sched:
                        # One ancestral step:
                        #   x_{t-1} = c0 * (x_t - c1 * eps) + sigma * z,
                        #   c0 = 1 / sqrt(alpha_t), c1 = (1 - alpha_t) / sqrt(1 - alpha_bar_t).
                        # No noise is added at t == 0. z matches the shape of the
                        # (possibly collapsed) eps, so a collapsed eps also shares z
                        # across base points.
                        bsz = dec_jts_fr_basepts.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #

                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        
                        beta = self.var_sched.betas[[t[0].item()] * bsz]  # NOTE(review): computed but unused
                        z = torch.randn_like(dec_jts_fr_basepts) if t_item > 0 else torch.zeros_like(dec_jts_fr_basepts)
                        
                        dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # x_{t-1}; presumably bsz x seq x nnjts x nnbase x 3
                        
                        # real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # theta
                        
                        # Disabled: average over base points and re-splat.
                        # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # bsz x seq x nnjts x nnbase x 3 #
                    else: # direct x_0 estimate from the predicted noise
                        dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts)
                        
                        
                    # Physics-guided correction on the last steps (t < 10):
                    # - reduce the per-base-point joints to one estimate;
                    # - subtract the offset returned by self.phy_projct_pred_joints;
                    # - broadcast the result back over the base-point axis.
                    if self.args.phy_guided_sampling and t[0].item() < 10:
                        # pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 
                        pred_dec_jts = dec_jts_fr_basepts.clone()
                        if self.args.sel_basepts_idx >= 0:
                            pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 #
                        else:
                            pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 #
                        # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals)
                        joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 
                        x_start_projed = pred_dec_jts - joints_proj_dir
                        # x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ##
                        # dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 #
                        
                        dec_ratio = 0.  # 0 => the projected joints fully replace the prediction
                        dec_jts_fr_basepts = dec_jts_fr_basepts * dec_ratio + x_start_projed.unsqueeze(-2) * (1. - dec_ratio)
                        
                        
                    # Back to offsets relative to each base point.
                    if not self.args.use_arti_obj:
                        real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(2)
                elif self.args.add_noise_onjts_single:
                    # Noise lives on a single joint set:
                    # - x_t is the mean over base points;
                    # - eps is the noise predicted for the first base point.
                    if not self.args.use_arti_obj:
                        jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        jts_fr_basepts = pert_rel_base_pts_to_rhand_joints + x['normed_base_pts'].unsqueeze(2)
                    jts_fr_basepts = jts_fr_basepts.mean(dim=-2)
                    # dec_jts_fr_basepts = real_dec_basejtsrel.mean(dim=-2) # + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # rel dec basejtsrel #
                    dec_jts_fr_basepts = real_dec_basejtsrel[..., 0, :]
                    if self.args.use_var_sched:
                        # Same ancestral step as above: x_{t-1} = c0 * (x_t - c1 * eps) + sigma * z.
                        bsz = dec_jts_fr_basepts.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #

                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        
                        beta = self.var_sched.betas[[t[0].item()] * bsz]  # NOTE(review): computed but unused
                        z = torch.randn_like(jts_fr_basepts) if t_item > 0 else torch.zeros_like(jts_fr_basepts)
                        dec_jts_fr_basepts = c0 * (jts_fr_basepts - c1 * dec_jts_fr_basepts) + sigma * z # x_{t-1} of the single joint set
                    else: # direct x_0 estimate from the predicted noise
                        dec_jts_fr_basepts = self._predict_xstart_from_eps(jts_fr_basepts, t=t, eps=dec_jts_fr_basepts)
                    # Splat the single joint set back into per-base-point relative offsets.
                    if not self.args.use_arti_obj:
                        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1)
                    else:
                        real_dec_basejtsrel = dec_jts_fr_basepts.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(2)
                else:
                    # Noise lives directly on the relative offsets: step / denoise them in place.
                    if self.args.use_var_sched:
                        bsz = real_dec_basejtsrel.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #
                        
                        # sigma = sigma / 2.

                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        
                        beta = self.var_sched.betas[[t[0].item()] * bsz]  # NOTE(review): computed but unused
                        
                        z = torch.randn_like(pert_rel_base_pts_to_rhand_joints) if t_item > 0 else torch.zeros_like(pert_rel_base_pts_to_rhand_joints)
                        
                        # z = torch.zeros_like(pert_rel_base_pts_to_rhand_joints)
                        
                        real_dec_basejtsrel = c0 * (pert_rel_base_pts_to_rhand_joints - c1 * real_dec_basejtsrel) + sigma * z # x_{t-1}
                        
                        # dec_jts_fr_basepts = real_dec_basejtsrel + x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # get dec_jts fr basepts 
                        # dec_jts_fr_basepts = dec_jts_fr_basepts.mean(dim=-2).unsqueeze(-2).repeat(1, 1, 1, dec_jts_fr_basepts.size(-2), 1) # repeated basepts 
                        # real_dec_basejtsrel = dec_jts_fr_basepts - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # real dec #
                        
                    else: # direct x_0 estimate from the predicted noise
                        real_dec_basejtsrel = self._predict_xstart_from_eps(pert_rel_base_pts_to_rhand_joints, t=t, eps=real_dec_basejtsrel)
            # Posterior q(x_{t-1} | x_t, x_0) using the decoded offsets as x_0.
            # NOTE(review): on the use_var_sched paths real_dec_basejtsrel already
            # holds a sampled x_{t-1}, not an x_0 estimate. The mean computed here
            # is then presumably unused by callers; confirm.
            real_dec_basejtsrel_mean, _, _ = self.q_posterior_mean_variance( # q posterior 
                x_start=real_dec_basejtsrel, x_t=pert_rel_base_pts_to_rhand_joints, t=t
            )
            # Fixed posterior variance / log-variance broadcast to the output shape.
            real_dec_basejtsrel_variance = _extract_into_tensor(model_variance, t, real_dec_basejtsrel.shape)
            real_dec_basejtsrel_log_variance = _extract_into_tensor(model_log_variance, t, real_dec_basejtsrel.shape)
            
            real_basejtsrel_seq_rt_dict = {
                "real_dec_basejtsrel": real_dec_basejtsrel,
                "real_dec_basejtsrel_mean": real_dec_basejtsrel_mean,
                "real_dec_basejtsrel_variance": real_dec_basejtsrel_variance,
                "real_dec_basejtsrel_log_variance": real_dec_basejtsrel_log_variance,
            }
            
            # Encoder-training mode: denoise the object base-point features
            # instead, then decode relative offsets from them. This overrides
            # 'real_dec_basejtsrel'.
            if self.args.train_enc: # pert_rel_base_pts_to_rhand_joints
                pert_obj_base_pts_feats = x['pert_obj_base_pts_feats']
                dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats(pert_obj_base_pts_feats, t)
                # Swap dims 0 and 1. NOTE(review): presumably to put the batch
                # dim first for the step below (bsz = size(0)); confirm.
                dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2)
                pert_obj_base_pts_feats = pert_obj_base_pts_feats.permute(1, 0, 2)
                # NOTE(review): this step is taken whenever pred_diff_noise is set,
                # without checking self.args.use_var_sched.
                if self.args.pred_diff_noise:
                    bsz = dec_obj_base_pts_feats.size(0)
                    t_item = t[0].item()
                    alpha = self.var_sched.alphas[t_item]
                    alpha_bar = self.var_sched.alpha_bars[t_item]
                    sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #

                    c0 = 1.0 / torch.sqrt(alpha)
                    c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                    
                    beta = self.var_sched.betas[[t[0].item()] * bsz]  # NOTE(review): computed but unused
                    z = torch.randn_like(dec_obj_base_pts_feats) if t_item > 0 else torch.zeros_like(dec_obj_base_pts_feats)
                    dec_obj_base_pts_feats = c0 * (pert_obj_base_pts_feats - c1 * dec_obj_base_pts_feats) + sigma * z # x_{t-1}
                dec_obj_base_pts_feats = dec_obj_base_pts_feats.permute(1, 0, 2)  # back to the original layout
                real_dec_basejtsrel = model.model.decode_realbasejtsrel_from_objbasefeats(dec_obj_base_pts_feats, x)
                real_basejtsrel_seq_rt_dict.update(
                    {
                        'real_dec_basejtsrel': real_dec_basejtsrel, 
                        'dec_obj_base_pts_feats': dec_obj_base_pts_feats,
                    }
                )
        else:
            real_basejtsrel_seq_rt_dict = {}
                
                # else: # x_{t-1}
                #     real_dec_basejtsrel = self._predict_xstart_from_eps(pert_rel_base_pts_to_rhand_joints, t=t, eps=real_dec_basejtsrel)
            
        
        if self.diff_basejtsrel:
            
            if 'basejtsrel_output' in out_dict:
                # Relative-offset + average-joints heads.
                pert_rel_base_pts_outputs = x['pert_rel_base_pts_to_rhand_joints'] # noisy x_t for the rel head
                pert_avg_joints_sequence = x['pert_avg_joints_sequence']
                
                # Swap the two axes before the last one so the output lines up with x_t.
                basejtsrel_output = out_dict['basejtsrel_output'].transpose(-2, -3).contiguous()
                avg_jts_outputs = out_dict['avg_jts_outputs']
                
                # if pert_rel_base_pts_outputs.size(0) == 1:
                #     pert_rel_base_pts_outputs = pert_rel_base_pts_outputs.repeat(pred_basejtsrel_seq_latents.size(0), 1, 1)
                
                if self.args.pred_diff_noise: # outputs are predicted noise -> convert to x_0
                    
                    basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output)
                    
                    avg_jts_outputs = self._predict_xstart_from_eps(pert_avg_joints_sequence, t=t, eps=avg_jts_outputs)
                    
                    # out_dict.update( # 
                    #     model.model.dec_basejtsrel_only_fr_latents(pred_basejtsrel_seq_latents, x['input_data'])
                    # )
                
                # NOTE(review): the posterior means / variances below are computed
                # but not placed in the returned dict (their keys are commented out
                # further down).
                basejtsrel_output_mean, _, _ = self.q_posterior_mean_variance(
                    x_start=basejtsrel_output, x_t=pert_rel_base_pts_outputs, t=t
                )
                
                avg_jts_outputs_mean, _, _ = self.q_posterior_mean_variance(
                    x_start=avg_jts_outputs, x_t=pert_avg_joints_sequence, t=t
                )
                # basejtsrel_seq_latents_mean = basejtsrel_seq_latents_mean.permute(1, 0, 2)
                
                # Fixed posterior variance broadcast to the rel-head output shape.
                basejtsrel_output_variance = _extract_into_tensor(model_variance, t, basejtsrel_output_mean.shape)
                basejtsrel_output_log_variance = _extract_into_tensor(model_log_variance, t, basejtsrel_output_mean.shape)
                
                # Fixed posterior variance broadcast to the avg-joints output shape.
                avg_jts_outputs_variance = _extract_into_tensor(model_variance, t, avg_jts_outputs_mean.shape)
                avg_jts_outputs_log_variance = _extract_into_tensor(model_log_variance, t, avg_jts_outputs_mean.shape)
            else:
                # Joint-offset head: x_t is the perturbed joint-offset sequence.
                pert_rel_base_pts_outputs = x['pert_joints_offset_sequence']
                basejtsrel_output = out_dict['joints_offset_output']
                # print(f"pert_rel_base_pts_outputs: {pert_rel_base_pts_outputs.size()}, basejtsrel_output: {basejtsrel_output.size()}")
                if self.args.pred_diff_noise: # output is predicted noise
                    # Either take one ancestral step or recover x_0 directly.
                    
                    if self.args.use_var_sched:
                        # x_{t-1} = c0 * (x_t - c1 * eps) + sigma * z; no noise at t == 0.
                        bsz = basejtsrel_output.size(0)
                        t_item = t[0].item()
                        alpha = self.var_sched.alphas[t_item]
                        alpha_bar = self.var_sched.alpha_bars[t_item]
                        sigma = self.var_sched.get_sigmas(t_item, 0.) # sigma #

                        c0 = 1.0 / torch.sqrt(alpha)
                        c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
                        
                        # x_t = traj[t] #
                        beta = self.var_sched.betas[[t[0].item()] * bsz]  # NOTE(review): computed but unused
                        # if mask is not None:
                        #     x_t = x_t * mask
                        # e_theta = self.net(x_t, beta=beta, context=context)
                        z = torch.randn_like(basejtsrel_output) if t_item > 0 else torch.zeros_like(basejtsrel_output)
                        basejtsrel_output = c0 * (pert_rel_base_pts_outputs - c1 * basejtsrel_output) + sigma * z # x_{t-1}
                    else:
                        basejtsrel_output = self._predict_xstart_from_eps(pert_rel_base_pts_outputs, t=t, eps=basejtsrel_output)
                    
                    
                    # Physics-guided correction (note: threshold t < 200 here vs.
                    # t < 10 in the rel branch above): subtract the offset returned
                    # by self.phy_projct_pred_joints from the predicted joints.
                    if self.args.phy_guided_sampling and t[0].item() < 200:
                        # pred_dec_jts = self._predict_xstart_from_eps(jts_fr_basepts, t, dec_jts_fr_basepts) # bsz x nf x nn_jts x nn_base_pts x 3 
                        pred_dec_jts = basejtsrel_output.clone()
                        # if self.args.sel_basepts_idx >= 0:
                        #     pred_dec_jts = pred_dec_jts[:, :, :, self.args.sel_basepts_idx] # bsz x nf x nn_jts x 3 #
                        # else:
                        #     pred_dec_jts = pred_dec_jts.mean(dim=-2) # bsz x nf x nn_jts x 3 #
                        # phy_projct_pred_joints(self, pred_joints, base_pts, base_normals)
                        joints_proj_dir = self.phy_projct_pred_joints(pred_dec_jts, x['normed_base_pts'], x['base_normals']) # bsz x nf x nn_jts x 3 
                        x_start_projed = pred_dec_jts - joints_proj_dir
                        # x_start_splat = x_start_projed.unsqueeze(-2) - x['normed_base_pts'].unsqueeze(1).unsqueeze(1) # bsz x nf x nn_jts x nn_base_pts x 3 ##
                        # dec_jts_fr_basepts_projed = self._predict_eps_from_xstart(jts_fr_basepts, t, x_start_splat) # bsz x nf x nn_jts x nn_base_pts x 3 #
                        dec_ratio = 0.  # 0 => the projected joints fully replace the prediction
                        basejtsrel_output = basejtsrel_output * dec_ratio + x_start_projed * (1. - dec_ratio)
                        
                ## use the predicted latents and pert_latents for the seq latents prediction ##
                # basejtsrel_output_mean, _, _ = self.q_posterior_mean_variance( # q posterior 
                #     x_start=basejtsrel_output, x_t=pert_rel_base_pts_outputs, t=t
                # )
                # ## from model_variance to basejtsrel_seq_latents ###
                # basejtsrel_output_variance = _extract_into_tensor(model_variance, t, basejtsrel_output_mean.shape)
                # basejtsrel_output_log_variance = _extract_into_tensor(model_log_variance, t, basejtsrel_output_mean.shape)
                
            # basejtsrel_output = out_dict["basejtsrel_output"]
            # print(f"basejtsrel_output: {basejtsrel_output.size()}")
            # Only the (stepped or x_0) output itself is returned for this branch.
            basejtsrel_seq_rt_dict = {
                # "avg_jts_outputs": avg_jts_outputs,
                # "basejtsrel_output_variance": basejtsrel_output_variance,
                # "basejtsrel_output_log_variance": basejtsrel_output_log_variance,
                # # "avg_jts_outputs_variance": avg_jts_outputs_variance,
                # "avg_jts_outputs_log_variance": avg_jts_outputs_log_variance,
                "basejtsrel_output": basejtsrel_output,
            }
        else:
            basejtsrel_seq_rt_dict = {}
        

        # Merge the per-branch results (later updates win on key clashes).
        # rt_dict.update(jts_seq_rt_dict)
        rt_dict.update(basejtsrel_seq_rt_dict)
        # rt_dict.update(basejtse_seq_rt_dict)
        rt_dict.update(real_basejtsrel_seq_rt_dict)
        # rt_dict.update(realbasejtsrel_to_joints_rt_dict)
        
        return rt_dict

    def _predict_xstart_from_eps(self, x_t, t, eps):
        """
        Recover the clean-sample estimate x_0 from x_t and a predicted noise eps.

        Computes ``c1[t] * x_t - c2[t] * eps``, where c1 / c2 are the
        ``sqrt_recip_alphas_cumprod`` / ``sqrt_recipm1_alphas_cumprod`` schedule
        buffers gathered at timestep ``t`` and broadcast to ``x_t.shape``.

        :param x_t: the noisy tensor at step t.
        :param t: timestep indices.
        :param eps: predicted noise; must have exactly the shape of ``x_t``.
        :return: the x_0 estimate.
        """
        assert x_t.shape == eps.shape
        xt_coef = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
        eps_coef = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return xt_coef * x_t - eps_coef * eps

    def _predict_xstart_from_xprev(self, x_t, t, xprev):
        """
        Recover x_0 from x_t and the posterior mean x_{t-1}.

        Inverts ``xprev = coef1 * x_0 + coef2 * x_t``, i.e. returns
        ``xprev / coef1 - (coef2 / coef1) * x_t`` with the coefficients
        gathered at timestep ``t``.

        :param x_t: the noisy tensor at step t.
        :param t: timestep indices.
        :param xprev: the x_{t-1} tensor; must have exactly the shape of ``x_t``.
        :return: the x_0 estimate.
        """
        assert x_t.shape == xprev.shape
        inv_coef1 = _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
        coef_ratio = _extract_into_tensor(
            self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
        )
        return inv_coef1 * xprev - coef_ratio * x_t

    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
        """
        Recover the noise eps implied by x_t and an x_0 prediction.

        This is the algebraic inverse of ``_predict_xstart_from_eps``:
        ``(c1[t] * x_t - pred_xstart) / c2[t]``, where c1 / c2 are the
        ``sqrt_recip_alphas_cumprod`` / ``sqrt_recipm1_alphas_cumprod`` buffers.

        :param x_t: the noisy tensor at step t.
        :param t: timestep indices.
        :param pred_xstart: the x_0 prediction.
        :return: the implied noise tensor.
        """
        scaled_xt = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
        denominator = _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return (scaled_xt - pred_xstart) / denominator

    def _scale_timesteps(self, t):
        """
        Map timestep indices onto the model's expected timestep range.

        When ``self.rescale_timesteps`` is set, ``t`` is converted to float and
        scaled by ``1000 / self.num_timesteps``. Otherwise ``t`` is returned
        unchanged (the same object).
        """
        if not self.rescale_timesteps:
            return t
        return t.float() * (1000.0 / self.num_timesteps)

    def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        The mean is shifted as ``mean + variance * gradient``. This is the
        conditioning strategy from Sohl-Dickstein et al. (2015).

        :param cond_fn: callable ``cond_fn(x, scaled_t, **model_kwargs)`` that
            returns the conditioning gradient w.r.t. ``x``.
        :param p_mean_var: dict with at least "mean" and "variance" tensors.
        :param x: the current sample, forwarded to ``cond_fn``.
        :param t: timestep tensor. It is passed through ``_scale_timesteps``
            before ``cond_fn`` sees it.
        :param model_kwargs: optional extra keyword arguments for ``cond_fn``.
            ``None`` (the default) means no extra arguments.
        :return: the conditioned mean.
        """
        # Without this, the documented default ``None`` would raise a
        # TypeError when expanded with ``**``.
        extra_kwargs = {} if model_kwargs is None else model_kwargs
        gradient = cond_fn(x, self._scale_timesteps(t), **extra_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute the mean for the previous step, given a function cond_fn that
        computes the gradient of a conditional log probability with respect to
        x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
        condition on y.

        Unlike ``condition_mean``, ``cond_fn`` receives the raw (unscaled) ``t``
        and the full ``p_mean_var`` dict: ``cond_fn(x, t, p_mean_var, **model_kwargs)``.
        The mean is shifted as ``mean + variance * gradient`` (Sohl-Dickstein
        et al., 2015).

        :param cond_fn: callable returning the conditioning gradient w.r.t. ``x``.
        :param p_mean_var: dict with at least "mean" and "variance" tensors.
        :param x: the current sample.
        :param t: timestep tensor, passed to ``cond_fn`` as-is.
        :param model_kwargs: optional extra keyword arguments for ``cond_fn``.
            ``None`` (the default) means no extra arguments.
        :return: the conditioned mean.
        """
        # ``**None`` would raise a TypeError, so normalise the default first.
        extra_kwargs = {} if model_kwargs is None else model_kwargs
        gradient = cond_fn(x, t, p_mean_var, **extra_kwargs)
        new_mean = (
            p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
        )
        return new_mean

    def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020). The steps are:
        1. Derive the noise ``eps`` implied by ``p_mean_var["pred_xstart"]``.
        2. Shift it by ``-sqrt(1 - alpha_bar_t) * cond_fn(...)``.
        3. Recompute "pred_xstart" and "mean" from the shifted ``eps``.

        :param model_kwargs: optional extra keyword arguments for ``cond_fn``.
            ``None`` (the default) means no extra arguments.
        :return: a shallow copy of ``p_mean_var`` with "pred_xstart" and
            "mean" replaced.
        """
        # ``**None`` would raise a TypeError, so normalise the default first.
        extra_kwargs = {} if model_kwargs is None else model_kwargs
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, self._scale_timesteps(t), **extra_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
        """
        Compute what the p_mean_variance output would have been, should the
        model's score function be conditioned by cond_fn.

        See condition_mean() for details on cond_fn.

        Unlike condition_mean(), this instead uses the conditioning strategy
        from Song et al (2020). It differs from condition_score() in how
        ``cond_fn`` is called: ``cond_fn(x, t, p_mean_var, **model_kwargs)``,
        with the raw timestep and the full ``p_mean_var`` dict.

        :param model_kwargs: optional extra keyword arguments for ``cond_fn``.
            ``None`` (the default) means no extra arguments.
        :return: a shallow copy of ``p_mean_var`` with "pred_xstart" and
            "mean" replaced.
        """
        # ``**None`` would raise a TypeError, so normalise the default first.
        extra_kwargs = {} if model_kwargs is None else model_kwargs
        alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)

        eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
        eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
            x, t, p_mean_var, **extra_kwargs
        )

        out = p_mean_var.copy()
        out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
        out["mean"], _, _ = self.q_posterior_mean_variance(
            x_start=out["pred_xstart"], x_t=x, t=t
        )
        return out

    def judge_activated(self, target_setting):
        """Return 1 when ``target_setting`` is truthy and 0 otherwise.

        ``p_sample`` sums these values to count how many diffusion branches
        are enabled.
        """
        return int(bool(target_setting))
    
    def p_sample(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        const_noise=False,
    ):
        """
        Sample x_{t-1} from the model at the given timestep.

        Each enabled branch (``self.diff_jts``, ``self.args.diff_realbasejtsrel_to_joints``,
        ``self.diff_realbasejtsrel``, ``self.diff_basejtsrel``) contributes its own
        entries to the returned dict. If more than one of the counted flags is
        set (``diff_basejtse`` is also counted), ``self.p_mean_variance_cond`` is
        used. Otherwise ``self.p_mean_variance`` is used.

        :param model: the model to sample from.
        :param x: the current state; indexed as a dict here
            (``x['joints_seq_latents']`` is read when ``diff_jts`` is set).
        :param t: timestep tensor with one entry per batch element. Entries
            equal to 0 get no noise added.
        :param clip_denoised: forwarded to the mean/variance function.
        :param denoised_fn: forwarded to the mean/variance function.
        :param cond_fn: accepted for API compatibility; NOT used by this method.
        :param model_kwargs: forwarded to the mean/variance function.
        :param const_noise: if True, the noise drawn for index 0 along dim 0 is
            reused for every index along dim 0.
        :return: dict holding, depending on the enabled branches:
            'joints_seq_latents_sample', 'joint_seq_output',
            'basejtsrel_seq_latents_sample', 'real_dec_basejtsrel',
            'dec_obj_base_pts_feats', 'dec_joints_offset_output'.
        """

        def _draw_noise(ref):
            # Standard-normal noise shaped like ``ref``. With const_noise, the
            # slice at dim-0 index 0 is tiled over dim 0. The tiling works for
            # any rank, so the result always keeps ``ref``'s shape.
            noise = th.randn_like(ref)
            if const_noise:
                noise = noise[[0]].repeat(noise.shape[0], *([1] * (noise.dim() - 1)))
            return noise

        def _nonzero_mask(ndim):
            # (bsz, 1, ..., 1) mask that removes the noise term where t == 0.
            return (t != 0).float().view(-1, *([1] * (ndim - 1)))

        n_activated = (
            self.judge_activated(self.diff_jts)
            + self.judge_activated(self.args.diff_realbasejtsrel_to_joints)
            + self.judge_activated(self.diff_realbasejtsrel)
            + self.judge_activated(self.diff_basejtsrel)
            + self.judge_activated(self.diff_basejtse)
        )
        # Several branches at once -> combined estimator; a single branch -> plain one.
        if n_activated >= 2:
            p_mean_variance_fn = self.p_mean_variance_cond
        else:
            p_mean_variance_fn = self.p_mean_variance

        out = p_mean_variance_fn(
            model,
            x,
            t,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            model_kwargs=model_kwargs,
        )

        jts_seq_rt_dict = {}
        if self.diff_jts:
            jts_latents = x['joints_seq_latents']
            if const_noise:
                print(f"joints latents hape, ", jts_latents.shape)
            jts_noise = _draw_noise(jts_latents)
            jts_mask = _nonzero_mask(jts_latents.dim())
            # Swap dims 0 and 1 so the per-batch mask lines up, then swap back.
            # This assumes 3-D tensors (permute(1, 0, 2)).
            joints_seq_latents_sample = (
                out["joints_seq_latents_mean"].permute(1, 0, 2)
                + jts_mask
                * th.exp(0.5 * out["joints_seq_latents_log_variance"].permute(1, 0, 2))
                * jts_noise.permute(1, 0, 2)
            ).permute(1, 0, 2)
            jts_seq_rt_dict = {
                "joints_seq_latents_sample": joints_seq_latents_sample,
                "joint_seq_output": out["joint_seq_output"],
            }

        realbasejtsrel_to_joints_rt_dict = {}
        if self.args.diff_realbasejtsrel_to_joints:
            realbasejtsrel_to_joints_rt_dict = {
                'dec_joints_offset_output': out['dec_joints_offset_output']
            }

        real_basejtsrel_rt_dict = {}
        if self.diff_realbasejtsrel:
            if self.args.train_enc or (self.args.pred_diff_noise and self.args.use_var_sched):
                # In these modes the estimator's output is taken as-is (no stochastic step).
                real_dec_basejtsrel = out['real_dec_basejtsrel']
            else:
                real_ref = out['real_dec_basejtsrel']
                real_noise = _draw_noise(real_ref)
                real_dec_basejtsrel = (
                    out["real_dec_basejtsrel_mean"]
                    + _nonzero_mask(real_ref.dim())
                    * th.exp(0.5 * out["real_dec_basejtsrel_log_variance"])
                    * real_noise
                )
            real_basejtsrel_rt_dict = {'real_dec_basejtsrel': real_dec_basejtsrel}
            if self.args.train_enc:
                real_basejtsrel_rt_dict['dec_obj_base_pts_feats'] = out['dec_obj_base_pts_feats']

        basejtsrel_rt_dict = {}
        if self.diff_basejtsrel:
            if self.args.pred_diff_noise and self.args.use_var_sched:
                basejtsrel_seq_latents_sample = out['basejtsrel_output']
            else:
                # "basejtsrel_output" acts as the mean of the Gaussian step.
                base_ref = out['basejtsrel_output']
                base_noise = _draw_noise(base_ref)
                basejtsrel_seq_latents_sample = (
                    base_ref
                    + _nonzero_mask(base_ref.dim())
                    * th.exp(0.5 * out["basejtsrel_output_log_variance"])
                    * base_noise
                )
            basejtsrel_rt_dict = {
                "basejtsrel_seq_latents_sample": basejtsrel_seq_latents_sample,
            }

        # Merge in a fixed order so the result's key order is stable.
        rt_dict = {}
        rt_dict.update(jts_seq_rt_dict)
        rt_dict.update(basejtsrel_rt_dict)
        rt_dict.update(real_basejtsrel_rt_dict)
        rt_dict.update(realbasejtsrel_to_joints_rt_dict)
        return rt_dict

    def p_sample_with_grad(
        self,
        model,
        x,
        t,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
    ):
        """
        Sample x_{t-1} while keeping autograd enabled for a gradient-based cond_fn.

        ``x`` is detached and marked ``requires_grad``. The mean/variance
        estimate, the noise draw and the optional ``condition_mean_with_grad``
        shift all run under ``th.enable_grad()``.

        :param model: the model to sample from.
        :param x: the current tensor at step t.
        :param t: timestep tensor. Entries equal to 0 get no noise added.
        :param clip_denoised: forwarded to ``p_mean_variance``.
        :param denoised_fn: forwarded to ``p_mean_variance``.
        :param cond_fn: if not None, used to shift the mean via
            ``condition_mean_with_grad``.
        :param model_kwargs: forwarded to ``p_mean_variance`` and ``cond_fn``.
        :return: dict with 'sample' (the x_{t-1} sample) and 'pred_xstart'
            (the detached x_0 prediction).
        """
        with th.enable_grad():
            x_grad = x.detach().requires_grad_()
            out = self.p_mean_variance(
                model,
                x_grad,
                t,
                clip_denoised=clip_denoised,
                denoised_fn=denoised_fn,
                model_kwargs=model_kwargs,
            )
            noise = th.randn_like(x_grad)
            # Broadcastable per-batch mask; zero where t == 0.
            keep_noise = (t != 0).float().view(-1, *([1] * (x_grad.dim() - 1)))
            if cond_fn is not None:
                out["mean"] = self.condition_mean_with_grad(
                    cond_fn, out, x_grad, t, model_kwargs=model_kwargs
                )
        noise_term = keep_noise * th.exp(0.5 * out["log_variance"]) * noise
        return {
            "sample": out["mean"] + noise_term,
            "pred_xstart": out["pred_xstart"].detach(),
        }

    def p_sample_loop(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        dump_steps=None,
        const_noise=False,
        st_timestep=None,
    ):
        """
        Run the full sampling loop and return the final (or selected) samples.

        Every argument except ``dump_steps`` is forwarded unchanged to
        ``p_sample_loop_progressive``.

        :param dump_steps: optional collection of iteration indices (0-based
            position in the progressive loop). When given, deep copies of the
            samples at those indices are collected and returned as a list.
            Otherwise only the last sample is returned.
        :param const_noise: if True, all samples use the same noise throughout
            sampling (forwarded).
        :return: the last yielded sample (``None`` if nothing was yielded), or
            the list of dumped samples when ``dump_steps`` is given.
        """
        progressive = self.p_sample_loop_progressive(
            model,
            shape,
            noise=noise,
            clip_denoised=clip_denoised,
            denoised_fn=denoised_fn,
            cond_fn=cond_fn,
            model_kwargs=model_kwargs,
            device=device,
            progress=progress,
            skip_timesteps=skip_timesteps,
            init_image=init_image,
            randomize_class=randomize_class,
            cond_fn_with_grad=cond_fn_with_grad,
            const_noise=const_noise,
            st_timestep=st_timestep,
        )

        if dump_steps is not None:
            # Deep-copy so later in-place updates by the sampler don't alter
            # the snapshots.
            return [
                deepcopy(step_sample)
                for step_idx, step_sample in enumerate(progressive)
                if step_idx in dump_steps
            ]

        last_sample = None
        for last_sample in progressive:
            pass
        return last_sample

    # p_sample_loop_progressive: generator that yields the p_sample() output dict at each timestep.
    def p_sample_loop_progressive(
        self,
        model,
        shape,
        noise=None,
        clip_denoised=True,
        denoised_fn=None,
        cond_fn=None,
        model_kwargs=None,
        device=None,
        progress=False,
        skip_timesteps=0,
        init_image=None,
        randomize_class=False,
        cond_fn_with_grad=False,
        const_noise=False,
        st_timestep=None,
    ): # 
        """  #  p_sample loop progressive #
        Generate samples from the model and yield intermediate samples from
        each timestep of diffusion.

        Arguments are the same as p_sample_loop().
        Returns a generator over dicts, where each dict is the return value of
        p_sample().
        """
        ####### ==== a conditional ssampling from init_images here!!! ==== #######
        ## === give joints shape here === ##
        ### ==== set the shape for sampling ==== ###
        ### === init image sshould not be none === ###
        base_pts = init_image['base_pts']
        base_normals = init_image['base_normals'] ## base normals ## base normals ##
        # rel_base_pts_to_rhand_joints = init_image['rel_base_pts_to_rhand_joints']
        # dist_base_pts_to_rhand_joints = init_image['dist_base_pts_to_rhand_joints']
        rhand_joints = init_image['rhand_joints']
        # rhand_joints = init_image['gt_rhand_joints']
        
        if self.args.use_anchors:
            # rhand_joints: bsz x nf x nn_anchors x 3 #
            rhand_joints = init_image['pert_rhand_anchors'] ## bsz x nf x nn_anchors x 3 -> for the anchors of the rhand #
        
        
        # rhand_joints = rhand_joints - ## vage for whether this model can work ###
        # avg_joints_sequence = 
        std_joints_sequence = torch.std(rhand_joints, dim=1)
        avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ##
        
        ## 
        if self.args.joint_std_v2:
            std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
            avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
        elif self.args.joint_std_v3:
            avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) # ws x 1 x 3 #
            std_joints_sequence = torch.std((rhand_joints - avg_joints_sequence.unsqueeze(1)).view(rhand_joints.size(0), -1), dim=1).unsqueeze(1).unsqueeze(1)
        
        joints_offset_sequence = rhand_joints - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnj x 3 #
        # if self.args.jts_sclae_stra == "std": # and only use the latents #
        #     joints_offset_sequence = joints_offset_sequence / std_joints_sequence
        
        if not self.args.use_arti_obj:
            normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 
        else:
            # base_pts: bsz x nf x nnb x 3 # 
            normed_base_pts = base_pts - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnb x 3 # 
        
        # joints_offset_sequence_ori = joints_offset_sequence.clone()
        # rhand_joints_ori = rhand_joints.clone()
        
        # jts scale stra # jts scale strategies ##
        # std_joints_sequence.unsqueeze(1), avg_joints_sequence.unsqueeze(1) 
        if self.args.jts_sclae_stra == "std":
            joints_offset_sequence = joints_offset_sequence / std_joints_sequence.unsqueeze(1)
            if not self.args.use_arti_obj:
                normed_base_pts = normed_base_pts / std_joints_sequence
            else:
                normed_base_pts = normed_base_pts / std_joints_sequence.unsqueeze(1)
        else:
            std_joints_sequence = torch.ones_like(std_joints_sequence)
        
        
        if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel
            # init_image['per_frame_avg_joints_rel'] = torch.zeros_like(init_image['per_frame_avg_joints_rel'])
            init_image['per_frame_std_joints_rel'] = torch.ones_like(init_image['per_frame_std_joints_rel'])

        init_image_avg_std_stats = {
            'rhand_joints': init_image['rhand_joints'],
            'per_frame_avg_joints_rel': init_image['per_frame_avg_joints_rel'],
            'per_frame_std_joints_rel': init_image['per_frame_std_joints_rel'],
            'per_frame_avg_joints_dists_rel': init_image['per_frame_avg_joints_dists_rel'],
            'per_frame_std_joints_dists_rel': init_image['per_frame_std_joints_dists_rel'],
        }
        
        if device is None:
            device = next(model.parameters()).device
        assert isinstance(shape, (tuple, list)) 
        
        ### without e normalization ###
        # indicies
        indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
        if st_timestep is not None: ## indices 
            indices = indices[-st_timestep: ]
            print(f"st_timestep: {st_timestep}, indices: {indices}")

        # joints_scaling_factor = 5.

        # rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ##
        # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ##
        # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ##
        # init_image['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True)
        # # x_start['per_frame_avg_joints_rel'] = torch
        # # bsz x ws x nnj x nnb x 3 #
        # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - init_image['per_frame_avg_joints_rel']) / init_image['per_frame_std_joints_rel']
        
        if not self.args.use_arti_obj:
            ### rel base pts to rhand joints #joints offset sequence # joints offset sequence # joints # joints_offset_sequence - normed_base_pts
            rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)  # bsz x nnf x nnj x nnb x 3 --> relative positions from baes pts to rhand joints #
        else:  # 
            rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)   # bsz x nf x nn_joints x nn_base_pts x 3 #
        
        maxx_rel_base_pts_to_rhand_joints, _ = torch.max(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        minn_rel_base_pts_to_rhand_joints, _ = torch.min(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        
        print(f"maxx_rel_base_pts_to_rhand_joints: {maxx_rel_base_pts_to_rhand_joints}, minn_rel_base_pts_to_rhand_joints: {minn_rel_base_pts_to_rhand_joints}")
        
        if self.args.real_basejtsrel_norm_stra == "mean":
            rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) 
            bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4]
            
            # exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous()
            # avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 
            # # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3
            # #### rel_base_pts_to_rhand_joints
            # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) # / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
            
            rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - rel_base_pts_to_rhand_joints.mean(dim=0, keepdim=True))
            
        elif self.args.real_basejtsrel_norm_stra == "std":
            # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints
            # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) 
            bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4] # rel 
            exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous()
            avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 
            
            rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints -  avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
            std_rel_base_pts_to_rhand_joints = torch.std(rel_base_pts_to_rhand_joints.view(bsz, -1), dim=-1, keepdim=True).unsqueeze(1) # bsz x 1 x 1
            rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)

        my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0] # t con
      
        # my_t = th.tensor([indices[-1]] * shape[0], device=device)
        my_t = th.tensor([indices[0]] * shape[0], device=device)

        ####### E #######
        # if not self.args.wo_e_normalization and self.args.e_normalization_stra == "cent": 
        #     bsz = e_disp_rel_to_base_along_normals.size(0)
        #     nf, nnj, nnb = e_disp_rel_to_base_along_normals.size()[1:] # high dimensional
        #     # the max value and min value of all values # #bs z x nnf x nnj x nnb --> for the along normals values and vt normals values ##
        #     maxx_e_disp_rel_to_base_along_normals, _ = torch.max(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1)
        #     minn_e_disp_rel_to_base_along_normals , _ = torch.min(e_disp_rel_to_base_along_normals.view(bsz, -1), dim=-1)
        #     maxx_e_disp_rel_to_base_vt_normals , _  = torch.max(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1)
        #     minn_e_disp_rel_to_base_vt_normals , _  = torch.min(e_disp_rel_to_baes_vt_normals.view(bsz, -1), dim=-1)
        #     maxx_e_disp_rel_to_base_along_normals = maxx_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb)
        #     minn_e_disp_rel_to_base_along_normals = minn_e_disp_rel_to_base_along_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb)
        #     maxx_e_disp_rel_to_base_vt_normals = maxx_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb)
        #     minn_e_disp_rel_to_base_vt_normals = minn_e_disp_rel_to_base_vt_normals.unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, 1, nnj, nnb)
            
        #     init_image['per_frame_avg_disp_along_normals'] = (maxx_e_disp_rel_to_base_along_normals + minn_e_disp_rel_to_base_along_normals) / 2.
        #     init_image['per_frame_avg_disp_vt_normals'] = (maxx_e_disp_rel_to_base_vt_normals + minn_e_disp_rel_to_base_vt_normals) / 2.
        #     init_image['per_frame_std_disp_along_normals'] = torch.ones_like(init_image['per_frame_std_disp_along_normals'])
        #     init_image['per_frame_std_disp_vt_normals']  = torch.ones_like(init_image['per_frame_std_disp_vt_normals'] )
            
        # # normalize ##  # base along normals #
        # e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - init_image['per_frame_avg_disp_along_normals']) / init_image['per_frame_std_disp_along_normals']
        # e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - init_image['per_frame_avg_disp_vt_normals']) / init_image['per_frame_std_disp_vt_normals']
        ####### E #######
        
        
        
        #### add noise onjts #### to base along normal
        # rigid objects -> moving; global pose # hwo about we do not add those canonicalization? 
        # and we only need correct contacts? to model # attaction forces? attraction forces # distances? 
        # distances? # k_f = e^{-k\cdot \Vert v_o - v_h\Vert}; --> the proximity value between each pair of points; --> points on the object one object denoising targets --> the distance from hand joint to the object surface; 
        # distance values --> distance values # manipulate the object --> add forces to the object #
        # manipulate the object --> add forces to the object # # map 
        # a simple case -> map joint points to the object points -> denoise relative positions; realtive positions; joint trajectory; 
        # values that describe the consistency between moving (value negative propotional to distances) * exp(\Vert x_o - x_h\Vert_2) ---> to describe the moving consistency between the hand and the object. 
        # contact map -> or a generatlized contact map -> 
        # add_noise_onjts, add_noise_onjts_single
        # noise_rel_base_pts_to_rhand_joints, rel_base_pts_to_rhand_joints
        ## pred jts # pred jts #
        ### Add noise to rel_baes_pts_to_rhand_joints ### # 
        noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints) ## bsz x ws x nnjts x 3 -> for each joint point #
        pert_rel_base_pts_to_rhand_joints = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ### # space spatial #
            rel_base_pts_to_rhand_joints, my_t, noise_rel_base_pts_to_rhand_joints
        )
        
        if self.args.add_noise_onjts: ### add noise on joints ###
            if self.args.use_same_noise_for_rep: ### use same noise for rep ###
                noise_joints_offset_output = torch.randn_like(joints_offset_sequence)
                pert_joints_offset_output = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
                    joints_offset_sequence, my_t, noise_joints_offset_output
                )
                if not self.args.use_arti_obj:
                    pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
                    noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
                else:
                    pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
                    noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1)
            else:
                if not self.args.use_arti_obj:
                    joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
                    noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) # normed_base_pts
                    pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
                        joints_offset_output_exp, my_t, noise_rel_base_pts_to_rhand_joints
                    )
                    pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1) 
                else:
                    joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1)
                    noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp) # normed_base_pts
                    pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
                        joints_offset_output_exp, my_t, noise_rel_base_pts_to_rhand_joints
                    )
                    pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(2)
                
        elif self.args.add_noise_onjts_single: # joints offset sequence # joints offset single #
            noise_joints_offset_output = torch.randn_like(joints_offset_sequence)
            pert_joints_offset_output = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
                joints_offset_sequence, my_t, noise_joints_offset_output
            )
            if not self.args.use_arti_obj:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
                noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
            else:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
                noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(2), 1)
            
        if self.args.train_enc:
            pert_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints
        

        # ### Add noise to rel_baes_pts_to_rhand_joints ###
        # noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints) ## bsz x ws x nnjts x 3 -> for each joint point #
        # pert_rel_base_pts_to_rhand_joints = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
        #     rel_base_pts_to_rhand_joints, my_t, noise_rel_base_pts_to_rhand_joints
        # )
        
        # noise_avg_joints_sequence = th.randn_like(avg_joints_sequence)
        # pert_avg_joints_sequence = self.q_sample( # bsz x nnjts x 3 --> the avg joitns for each ...
        #     avg_joints_sequence, my_t, noise_avg_joints_sequence
        # )
        
        # joints_offset_sequence # joints offset sequence ##
        noise_joints_offset_sequence = th.randn_like(joints_offset_sequence)
        print(f"my_t: {my_t}")
        pert_joints_offset_sequence = self.q_sample(
            joints_offset_sequence, my_t, noise_joints_offset_sequence
        )
        
        if self.args.add_noise_onjts_single:
            noise_joints_offset_sequence = noise_joints_offset_output
            pert_joints_offset_sequence = pert_joints_offset_output
            
        if not self.args.use_arti_obj:
            if self.args.add_noise_onjts_single or (self.diff_realbasejtsrel and self.diff_basejtsrel):
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
            
            if self.args.finetune_with_cond:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
                print(f"pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}")
            
        else:
            if self.args.add_noise_onjts_single or (self.diff_realbasejtsrel and self.diff_basejtsrel):
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
            
            if self.args.finetune_with_cond:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
                print(f"pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}")
            
            
        # sv_pert_dict = {
        #     'joints_offset_sequence': joints_offset_sequence.detach().cpu().numpy(),
        #     'pert_joints_offset_sequence': pert_joints_offset_sequence.detach().cpu().numpy(),
        #     'noise_joints_offset_sequence': noise_joints_offset_sequence.detach().cpu().numpy(),
        #     'joints_offset_sequence_ori': joints_offset_sequence_ori.detach().cpu().numpy(),
        #     'rhand_joints_ori': rhand_joints.detach().cpu().numpy(),
        # }
        
        # sv_pert_dict_fn = "tot_pert_jts_sequence_dict.npy" # this file @!!!!!
        # np.save(sv_pert_dict_fn, sv_pert_dict)
        # print(f"pert data saved to {sv_pert_dict_fn} !!!!")
        
        if self.args.rnd_noise:
            pert_joints_offset_sequence = noise_joints_offset_sequence
            # pert_avg_joints_sequence = noise_avg_joints_sequence
            
        if not self.args.use_arti_obj:
            ## minus normed base pts here ## # ### normed 
            pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
        else:
            pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
        
        # the strategy of adding noise to the representations #
        
        # tot_pert_joint = pert_joints_offset_sequence * std_joints_sequence.unsqueeze(1) + pert_avg_joints_sequence.unsqueeze(1)
        # np.save("tot_pert_joint.npy", tot_pert_joint.detach().cpu().numpy())
        
        ####### E #######
        # # noise_e_disp_rel_to_base_along_normals, noise_e_disp_rel_to_base_vt_normals  #
        # # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # wo e normalization;
        # # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # wo_e_normalization #
        # noise_e_disp_rel_to_base_along_normals = torch.randn_like(e_disp_rel_to_base_along_normals)
        # pert_e_disp_rel_to_base_along_normals = self.q_sample(
        #     e_disp_rel_to_base_along_normals, my_t, noise_e_disp_rel_to_base_along_normals  
        # )
        
        # noise_e_disp_rel_to_base_vt_normals = torch.randn_like(e_disp_rel_to_baes_vt_normals)
        # pert_e_disp_rel_to_base_vt_normals = self.q_sample(
        #     e_disp_rel_to_baes_vt_normals, my_t, noise_e_disp_rel_to_base_vt_normals
        # )
        ####### E #######
        
        
        input_data = {
            'base_pts': base_pts,
            'base_normals': base_normals,
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_pert_rhand_joints, 
            # 'dist_base_pts_to_rhand_joints': dist_base_pts_to_pert_rhand_joints,
            # 'pert_rhand_joints': pert_normed_rhand_joints,
            # 'pert_rhand_joints': pert_scaled_rhand_joints,
            'rhand_joints': rhand_joints, # 
            'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints,
            # 'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(),
            # 'avg_joints_sequence': avg_joints_sequence,
            # 'pert_avg_joints_sequence': pert_avg_joints_sequence, ## pert avg joints sequence 
            # 'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert realtive base pts to rhand joints ##
            'pert_joints_offset_sequence': pert_joints_offset_sequence,
            'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert realtive base pts to rhand joints ##
            # 'pert_joints_offset_sequence': pert_joints_offset_sequence,
            'normed_base_pts': normed_base_pts,
            'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred, ## pert_rel_base_pts_to_joints_for_jts_pred for the bsz x nf x nnj x nnb x 3 --> from base points to joints ####
        }
        
        ####### E #######
        # primal space denoising ->
        # input_data.update(
        #     {
        #         'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, 
        #         'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
        #         'pert_e_disp_rel_to_base_along_normals': pert_e_disp_rel_to_base_along_normals,
        #         'pert_e_disp_rel_to_base_vt_normals': pert_e_disp_rel_to_base_vt_normals,
        #     }
        # )
        ####### E #######
        
        # input #
        input_data.update(init_image_avg_std_stats)
        input_data['rhand_joints'] = rhand_joints # normed 
        
        # self.args.real_basejtsrel_norm_stra == "std":
            # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints
        if self.args.real_basejtsrel_norm_stra == "std":
            input_data.update(
                {
                    'avg_rel_base_pts_to_rhand_joints': avg_rel_base_pts_to_rhand_joints, 
                    'std_rel_base_pts_to_rhand_joints': std_rel_base_pts_to_rhand_joints,
                }
            )

        if self.args.train_enc: #
            # model(input_data, self._scale_timesteps(t).clone())
            out_dict = model(input_data, self._scale_timesteps(my_t).clone())
            obj_base_pts_feats = out_dict['obj_base_pts_feats'] # obj base pts feats #
            # noise_obj_base_pts_feats = torch.zeros_like(obj_base_pts_feats)
            noise_obj_base_pts_feats = torch.randn_like(obj_base_pts_feats)
            pert_obj_base_pts_feats = self.q_sample(
                obj_base_pts_feats.permute(1, 0, 2), my_t, noise_obj_base_pts_feats.permute(1, 0, 2)
            ).permute(1, 0, 2)
            
            if self.args.rnd_noise:
                pert_obj_base_pts_feats = noise_obj_base_pts_feats
            
            input_data['pert_obj_base_pts_feats'] = pert_obj_base_pts_feats


        model_kwargs = {
            k: val for k, val in init_image.items() if k not in input_data
        }

        if progress:
            # Lazy import so that we don't depend on tqdm. # 
            from tqdm.auto import tqdm
            indices = tqdm(indices)

        for i_idx, i in enumerate(indices):
            t = th.tensor([i] * shape[0], device=device)
            if randomize_class and 'y' in model_kwargs:
                model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
                                               size=model_kwargs['y'].shape, # size of y 
                                               device=model_kwargs['y'].device) # device of y 
            with th.no_grad(): # inter_optim # progress #
                # p_sample_with_grad ##  p_sample with grid ##s # or for each joints -> the features ->
                sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample

                out = sample_fn( 
                    model,
                    input_data, ## sample from input data ##
                    t,
                    clip_denoised=clip_denoised,
                    denoised_fn=denoised_fn,
                    cond_fn=cond_fn,
                    model_kwargs=model_kwargs, ## 
                    const_noise=const_noise, ## # new representation strategies; resolve penerations; resolve penerations ### penetrations for new representations ##
                )

            if self.diff_basejtsrel: # basejtrel #
                basejtsrel_seq_latents_sample = out["basejtsrel_seq_latents_sample"] ## basejtsrle output ## ## ## basejtsrel output ##
                # 'real_dec_basejtsrel': real_dec_basejtsrel, 
                        # 'dec_obj_base_pts_feats': dec_obj_base_pts_feats,
                # if self.args.pred_joints_offset: # pred 
                # basejtsrel_seq_latents_sample: bsz x nf x nnj x 3  # basejtsrel_seq_latents_sample --> basejtsrel_seq_latents_sample #
                # sampled_rhand_joints = basejtsrel_seq_latents_sample  * std_joints_sequence.unsqueeze(1) + avg_jts_outputs_sample.unsqueeze(1)
                sampled_rhand_joints = basejtsrel_seq_latents_sample * std_joints_sequence.unsqueeze(1) + avg_joints_sequence.unsqueeze(1)

                # print(f"basejtsrel_seq_latents_sample: {basejtsrel_seq_latents_sample.size()}, normed_base_pts: {normed_base_pts.size()}")
                ### pert rel bae pts to rhand joints ### # ### normed base pts ##
                
                if not self.args.use_arti_obj:
                    pert_rel_base_pts_to_rhand_joints = basejtsrel_seq_latents_sample.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
                else:
                    pert_rel_base_pts_to_rhand_joints = basejtsrel_seq_latents_sample.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
                # print(f"Sampling with pert_rel_base_pts_to_rhand_joints: {pert_rel_base_pts_to_rhand_joints.size()}")

                basejtsrel_seq_dec_in_dict = {
                    # finetune_with_cond
                    # 'pert_avg_joints_sequence': out["avg_jts_outputs_sample"] if 'avg_jts_outputs_sample' in out else pert_avg_joints_sequence, ## for avg-jts sequence ##
                    'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert realtive base pts to rhand joints ##
                    'sampled_rhand_joints': sampled_rhand_joints, ## sampled rhand joints ## # rhand joints ## ## and another choice ## another choice ##
                    'pert_joints_offset_sequence':  out["basejtsrel_seq_latents_sample"],
                    
                }
                input_data.update(basejtsrel_seq_dec_in_dict)
            else:
                # basejtsrel_seq_input_dict = {}
                basejtsrel_seq_dec_in_dict = {}
            
            if self.diff_realbasejtsrel : # 
                real_dec_basejtsrel = out["real_dec_basejtsrel"] # bsz x nf x nnj x nnb x 3 #
                # avg_rel_base_pts_to_rhand_joints, std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
                # add_noise_onjts, add_noise_onjts_single #### add_noise_onjts; add_noise_onjts_single ####
                if self.args.real_basejtsrel_norm_stra == "std" and (not self.args.add_noise_onjts) and (not self.args.add_noise_onjts_single):
                    real_dec_basejtsrel_pred_sample = real_dec_basejtsrel * std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1) + avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
                    real_dec_basejtsrel_pred_sample = real_dec_basejtsrel_pred_sample + normed_base_pts.unsqueeze(1).unsqueeze(1)
                else:
                    # real_dec_basejtsrel_pred_sample = real_dec_basejtsrel.clone()
                    if not self.args.use_arti_obj:
                        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel + normed_base_pts.unsqueeze(1).unsqueeze(1) 
                    else:
                        real_dec_basejtsrel_pred_sample = real_dec_basejtsrel + normed_base_pts.unsqueeze(2)
                # real dec basejtsrel pred sample #
                # real_dec_basejtsrel_pred_sample = real_dec_basejtsrel_pred_sample + normed_base_pts.unsqueeze(1).unsqueeze(1) # bsz x nf x nnj x nnb x 3 #
                # # std_joints_sequence.unsqueeze(1), avg_joints_sequence.unsqueeze(1)  # real pred samples #
                # if self.args.use_t == 1000:
                #     sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :] 
                #     # sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2) 
                # else:
                #     sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2) # bsz x nf x nnj x 3 #
                    # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :] 
                if self.args.sel_basepts_idx >= 0:
                    sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., self.args.sel_basepts_idx, :] 
                else:
                    sampled_rhand_joints = real_dec_basejtsrel_pred_sample.mean(dim=-2) 
                # sampled_rhand_joints = real_dec_basejtsrel_pred_sample[..., 0, :] # std joints sequence; 
                # std_joints #
                sampled_rhand_joints = sampled_rhand_joints * std_joints_sequence.unsqueeze(1) +  avg_joints_sequence.unsqueeze(1) 
                
                real_basejtsrel_dec_in_dict = { # real_dec_basejtsrel #
                    'pert_rel_base_pts_to_rhand_joints': real_dec_basejtsrel, ## realdecbasejtsrel
                    # 'sampled_rhand_joints': sampled_rhand_joints,
                }
                if not self.diff_basejtsrel:
                    real_basejtsrel_dec_in_dict['sampled_rhand_joints'] = sampled_rhand_joints
                if self.args.train_enc:
                    real_basejtsrel_dec_in_dict['pert_obj_base_pts_feats'] = out['dec_obj_base_pts_feats']
                input_data.update(real_basejtsrel_dec_in_dict)
            else:
                real_basejtsrel_dec_in_dict = {}
                
            
            
            yield input_data
                        


    def _vb_terms_bpd(
        self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
    ):
        """
        Compute one term of the variational lower bound, measured in bits.

        Values are divided by ln(2), so they come out in bits rather than
        nats. This makes them directly comparable with numbers reported in
        other papers.

        :param model: the denoising model, forwarded to ``p_mean_variance``.
        :param x_start: the clean sample x_0.
        :param x_t: the noised sample at timestep ``t``.
        :param t: a batch of timestep indices.
        :param clip_denoised: forwarded to ``p_mean_variance``.
        :param model_kwargs: extra keyword arguments, forwarded to
            ``p_mean_variance``.
        :return: a dict with the following keys:
                 - 'output': a shape [N] tensor. For batch entries where
                   ``t == 0`` it holds the decoder NLL; otherwise it holds
                   the KL term.
                 - 'pred_xstart': the model's x_0 prediction.
        """
        ln2 = np.log(2.0)

        # Posterior q(x_{t-1} | x_t, x_0); its (unclipped) variance is unused.
        posterior_mean, _, posterior_log_var = self.q_posterior_mean_variance(
            x_start=x_start, x_t=x_t, t=t
        )
        # Model distribution p(x_{t-1} | x_t).
        model_stats = self.p_mean_variance(
            model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
        )
        pred_mean = model_stats["mean"]
        pred_log_var = model_stats["log_variance"]

        # KL(q || p), averaged over all non-batch dimensions, converted to bits.
        kl_bits = mean_flat(
            normal_kl(posterior_mean, posterior_log_var, pred_mean, pred_log_var)
        ) / ln2

        # Negative log-likelihood of x_0 under the discretized Gaussian decoder;
        # log_scales is half the log-variance (log std).
        nll = -discretized_gaussian_log_likelihood(
            x_start, means=pred_mean, log_scales=0.5 * pred_log_var
        )
        assert nll.shape == x_start.shape
        nll_bits = mean_flat(nll) / ln2

        # First timestep -> decoder NLL; any other timestep ->
        # KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)).
        is_first_step = t == 0
        return {
            "output": th.where(is_first_step, nll_bits, kl_bits),
            "pred_xstart": model_stats["pred_xstart"],
        }


    ## 
    ## training losses ## ## training losses ##
    def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None):
        """ # training losses # training losses for rel/dist representations ## 
        Compute training losses for a single timestep.

        :param model: the model to evaluate loss on.
        :param x_start: the [N x C x ...] tensor of inputs.
        :param t: a batch of timestep indices.
        :param model_kwargs: if not None, a dict of extra keyword arguments to
            pass to the model. This can be used for conditioning.
        :param noise: if specified, the specific Gaussian noise to try to remove.
        :return: a dict with the key "loss" containing a tensor of shape [N].
                 Some mean or variance settings may also have other keys.
        """
        
        enc = model.model ## model.model
        mask = model_kwargs['y']['mask'] ## rot2xyz
        get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation,
                                             glob=enc.glob, ## rot2xyz ## ## rot2xyz ##
                                             # jointstype='vertices',  # 3.4 iter/sec # USED ALSO IN MotionCLIP
                                             jointstype='smpl',  # 3.4 iter/sec
                                             vertstrans=False)
        # bsz x ws x nnj x 3 # #
        # base_pts: bsz x nnb x 3 #
        # base_normals: bsz x nnb x 3 # # base normals  # base normals #
        # bsz x ws x nnjts x 3 #
        rhand_joints = x_start['rhand_joints']
        # bsz x nnbase x 3 #
        base_pts = x_start['base_pts']
        # bsz x ws x nnbase x 3 #
        base_normals = x_start['base_normals']
        
        # 
        if self.args.use_anchors:
            # rhand_joints: bsz x nf x nn_anchors x 3 # ## rhand verts ##
            rhand_joints = x_start['rhand_anchors'] ## bsz x nf x nn_anchors x 3 -> for the anchors of the rhand #
        
        # base_pts, base_normals, rhand_joints # ### rhand verts ##
        avg_joints_sequence = torch.mean(rhand_joints, dim=1) # bsz x nnj x 3 ---> for rhand joints ##
        std_joints_sequence = torch.std(rhand_joints, dim=1)
        
        if self.args.joint_std_v2:
            std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
            avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
        elif self.args.joint_std_v3:
            # std_joints_sequence = torch.std(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1)
            avg_joints_sequence = torch.mean(rhand_joints.view(rhand_joints.size(0), -1, 3), dim=1).unsqueeze(1) # bsz x 3 --> bsz x 1 x 3; 
            std_joints_sequence = torch.std((rhand_joints - avg_joints_sequence.unsqueeze(1)).view(rhand_joints.size(0), -1), dim=1).unsqueeze(1).unsqueeze(1)
            
        # if self.args.use_anchor:
        #     avg_joints_sequence = torch.mean(rhand_joints, dim=1) 
            
        # normed_base_pts, joints_offset_sequence # 
        joints_offset_sequence = rhand_joints - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnj x 3 # # bsz x nf x nnj x 3
        
        # normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 
        if not self.args.use_arti_obj:
            normed_base_pts = base_pts - avg_joints_sequence # bsz x nnb x 3 
        else:
            # base_pts: bsz x nf x nnb x 3 # 
            normed_base_pts = base_pts - avg_joints_sequence.unsqueeze(1) # bsz x nf x nnb x 3 # 
        
        
        if self.args.jts_sclae_stra == "std": ## jts scale stra ##
            joints_offset_sequence = joints_offset_sequence / std_joints_sequence.unsqueeze(1)
            # normed_base_pts = normed_base_pts / std_joints_sequence
            if not self.args.use_arti_obj:
                normed_base_pts = normed_base_pts / std_joints_sequence
            else:
                normed_base_pts = normed_base_pts / std_joints_sequence.unsqueeze(1)
        else:
            std_joints_sequence = torch.ones_like(std_joints_sequence)
        
        # # bsz x ws x nnjts x nnbase x 3 #
        # rel_base_pts_to_rhand_joints = x_start['rel_base_pts_to_rhand_joints']
        # # bsz x ws x nnjts x nnbase #
        # dist_base_pts_to_rhand_joints = x_start['dist_base_pts_to_rhand_joints']
        # if 'sampled_base_pts_nearest_obj_pc' in x_start:
        #     ambient_xstart_dict = {
        #         'sampled_base_pts_nearest_obj_pc': x_start['sampled_base_pts_nearest_obj_pc'],
        #         'sampled_base_pts_nearest_obj_vns': x_start['sampled_base_pts_nearest_obj_vns'],
        #     }

        # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals # wo e normalization;
        # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals # wo_e_normalization # 
        if self.args.wo_e_normalization: # per frame avg disp along normals #
            x_start['per_frame_avg_disp_along_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_along_normals'])
            x_start['per_frame_avg_disp_vt_normals'] = torch.zeros_like(x_start['per_frame_avg_disp_vt_normals'])
            x_start['per_frame_std_disp_along_normals'] = torch.ones_like(x_start['per_frame_std_disp_along_normals'])
            x_start['per_frame_std_disp_vt_normals'] = torch.ones_like(x_start['per_frame_std_disp_vt_normals'])
       
        
        # psatial -> e normalization and centralize? 
            

        if self.args.wo_rel_normalization: # per_frame_avg_joints_rel, per_frame_std_joints_rel
            # x_start['per_frame_avg_joints_rel'] = torch.zeros_like(x_start['per_frame_avg_joints_rel'])
            x_start['per_frame_std_joints_rel'] = torch.ones_like(x_start['per_frame_std_joints_rel'])


        # normed_base_pts, joints_offset_sequence # 
        ## rel_base_pts_to_rhand_joints: bsz x ws x nnj x nnb x 3 ## ## base pts to rhand joints ##
        # base_pts: bsz x nnb x 3 # # base_normals: bsz x nnb x 3 ## ## relative joint positions ### ## bsz x 
        
        # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nnj x nnb x 3 # rhand_joints ##
        # x_start['per_frame_avg_joints_rel'] = torch.mean(rel_base_pts_to_rhand_joints, dim=1, keepdim=True)
        # # x_start['per_frame_avg_joints_rel'] = torch
        # # bsz x ws x nnj x nnb x 3 # # per_frame_avg_joints_rel #
        # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - x_start['per_frame_avg_joints_rel']) / x_start['per_frame_std_joints_rel']
        
        # rel_base_pts_to_rhand_joints
        ## rel_base_pts_to_rhand_joints -> joints offset
        ## Normalization stra 1 --> no normalization for joints sequences ##  # normed base pts 
        if not self.args.use_arti_obj:
            rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1)  # bsz x nnf x nnj x nnb x 3 --> relative positions from baes pts to rhand joints #
        else:
            rel_base_pts_to_rhand_joints = joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
        
        # other normalization strategies>
        if self.args.real_basejtsrel_norm_stra == "mean":
            rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) 
            bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4]
            
            # exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous()
            # avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 
            # # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3
            # #### rel_base_pts_to_rhand_joints ####
            # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) # / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
            
            rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - rel_base_pts_to_rhand_joints.mean(dim=0, keepdim=True))
            
        elif self.args.real_basejtsrel_norm_stra == "std":
            # rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) 
            bsz, nf, nnj, nnb = rel_base_pts_to_rhand_joints.size()[:4]
            exp_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints.view(bsz, nf * nnj * nnb, -1).contiguous()
            avg_rel_base_pts_to_rhand_joints = torch.mean(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3 
            
            rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints -  avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
            std_rel_base_pts_to_rhand_joints = torch.std(rel_base_pts_to_rhand_joints.view(bsz, -1), dim=-1, keepdim=True).unsqueeze(1) # bsz x 1 x 1
            rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
            # rel_base_pts_to_rhand_joints --> # 
            # print("here using std!!")
            # std_rel_base_pts_to_rhand_joints = torch.std(exp_rel_base_pts_to_rhand_joints, dim=1, keepdim=True) # bsz x 1 x 3
            #### rel_base_pts_to_rhand_joints ####
            # rel_base_pts_to_rhand_joints = (rel_base_pts_to_rhand_joints - avg_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)) / std_rel_base_pts_to_rhand_joints.unsqueeze(1).unsqueeze(1)
        
        
        ## rep; motion-to-rep; 
        # construct statistics, normalize values #
        # joints_scaling_factor = 5. # 
        ''' GET rel and dists '''  ## rep and rhand_joints ##### # rep; motion-to-rep #
        if self.denoising_stra == "rep":
            
            # rel_base_pts_to_rhand_joints: bsz x nf x nnj x nnb x 3 #
            if not self.args.use_arti_obj:
                denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(1).unsqueeze(1) # bsz x nf x nnj x nnb x 3 
            else:
                denormed_rel_base_pts_to_rhand_joints = rhand_joints.unsqueeze(-2) - base_pts.unsqueeze(2) # bsz x nf x nnj x nnb x 3 
            
            
            ##### E ######
            # ### Calculate moving related energies ###
            # # denormed_rhand_joints: bsz x nf x nnj x 3 -> denormed rhand joints here #
            # # pert rhand joints # ## denormed base pts to rhand joints -> denormed rel positions ##
            # # denormed_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints * x_start['per_frame_std_joints_rel'] + x_start['per_frame_avg_joints_rel']
            # # denormed_rhand_joints = rhand_joints * extents_rhand_joints.unsqueeze(1) + avg_exp_rhand_joints.unsqueeze(1)
            # denormed_dist_rel_base_pts_to_pert_rhand_joints = torch.sum(
            #     denormed_rel_base_pts_to_rhand_joints * base_normals.unsqueeze(1).unsqueeze(1), dim=-1 # denormed relative distances 
            # ) ## l2 real base pts 
            # k_f = 1. ## l2 rel base pts to pert rhand joints ##
            # # l2_rel_base_pts_to_pert_rhand_joints: bsz x nf x nnj x nnb #
            # l2_rel_base_pts_to_pert_rhand_joints = torch.norm(denormed_rel_base_pts_to_rhand_joints, dim=-1)
            # ### att_forces ##
            # att_forces = torch.exp(-k_f * l2_rel_base_pts_to_pert_rhand_joints) # bsz x nf x nnj x nnb #
            # # bsz x (ws - 1) x nnj x nnb #
            # att_forces = att_forces[:, :-1, :, :] # attraction forces -1 #
            # # rhand_joints: ws x nnj x 3 # -> (ws - 1) x nnj x 3 ## rhand_joints ##
            # # bsz x (ws - 1) x nnj x 3 --> displacements s#
            # denormed_rhand_joints_disp = rhand_joints[:, 1:, :, :] - rhand_joints[:, :-1, :, :]
            # # distance -- base_normalss,; (ws - 1) x nnj x nnb x 3 --> bsz x (ws - 1) x nnj x nnb # 
            # # signed_dist_base_pts_to_pert_rhand_joints_along_normal # bsz x (ws - 1) x nnj x nnb #
            # signed_dist_base_pts_to_rhand_joints_along_normal = torch.sum(
            #     base_normals.unsqueeze(1).unsqueeze(1) * denormed_rhand_joints_disp.unsqueeze(-2), dim=-1
            # ) ## signed dist base pts to rhand joints along normals #
            # # rel_base_pts_to_pert_rhand_joints_vt_normal: bsz x (ws -1) x nnj x nnb x 3 -> the relative positions vertical to base normals #
            # rel_base_pts_to_rhand_joints_vt_normal = denormed_rhand_joints_disp.unsqueeze(-2)  - signed_dist_base_pts_to_rhand_joints_along_normal.unsqueeze(-1) * base_normals.unsqueeze(1).unsqueeze(1)
            # dist_base_pts_to_rhand_joints_vt_normal = torch.sqrt(torch.sum(
            #     rel_base_pts_to_rhand_joints_vt_normal ** 2, dim=-1
            # ))
            # k_a = 1.
            # k_b = 1.
            ### 
            
            # e_disp_rel_to_base_along_normals = k_a * att_forces * torch.abs(signed_dist_base_pts_to_rhand_joints_along_normal)
            # # (ws - 1) x nnj x nnb # -> dist vt normals # ## 
            # e_disp_rel_to_baes_vt_normals = k_b * att_forces * dist_base_pts_to_rhand_joints_vt_normal
            # # nf x nnj x nnb ---> dist_vt_normals -> nf x nnj x nnb # # torch.sqrt() ##
            # 
            ##### E ######
            # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
            # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
            # print("per frmae avg jalong normals,", x_start['per_frame_avg_disp_along_normals'].size(), "per frame std along normals",  x_start['per_frame_std_disp_along_normals'].size(), "e_disp_rel_to_base_along_normals", e_disp_rel_to_base_along_normals.size(), "e_disp_rel_to_baes_vt_normals", e_disp_rel_to_baes_vt_normals.size())
            # per_frame_avg_disp_along_normals, per_frame_std_disp_along_normals
            # per_frame_avg_disp_vt_normals, per_frame_std_disp_vt_normals
            # e_disp_rel_to_base_along_normals = (e_disp_rel_to_base_along_normals - x_start['per_frame_avg_disp_along_normals']) / x_start['per_frame_std_disp_along_normals']
            # e_disp_rel_to_baes_vt_normals = (e_disp_rel_to_baes_vt_normals - x_start['per_frame_avg_disp_vt_normals']) / x_start['per_frame_std_disp_vt_normals']

        else:  
            raise ValueError(f"Unrecognized denoising_stra: {self.denoising_stra}")
        ''' GET rel and dists '''

        # noise_rel_base_pts_to_rhand_joints, rel_base_pts_to_rhand_joints
        ## pred jts # pred jts #
        ### Add noise to rel_baes_pts_to_rhand_joints ###
        noise_rel_base_pts_to_rhand_joints = th.randn_like(rel_base_pts_to_rhand_joints) ## bsz x ws x nnjts x 3 -> for each joint point #
        pert_rel_base_pts_to_rhand_joints = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ##
            rel_base_pts_to_rhand_joints, t, noise_rel_base_pts_to_rhand_joints
        )
        
        # add_noise_onjts, add_noise_onjts_single
        ### ==== add noise on joints ==== ###
        ### ==== add noise on joints and then use them to calculate the perturbed rel-base-pts-to-rhand-joints ==== ###
        if self.args.add_noise_onjts: # add_noise_onjts_single # -->  # bsz x nf x nnjts x nn_base_pts x 3
            joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(-2), 1)
            noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp)
            pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 ##
                joints_offset_output_exp, t, noise_rel_base_pts_to_rhand_joints
            )
            # pert_rel_base_pts_to_rhand_joints: bsz x seq_len x nnj x nnb x 3 --> the rhand-joints; 
            if not self.args.use_arti_obj:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1)  # 
            else:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(2)
        elif self.args.add_noise_onjts_single: # joints offset sequence # joints offset single #
            noise_joints_offset_output = torch.randn_like(joints_offset_sequence)
            pert_joints_offset_output = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
                joints_offset_sequence, t, noise_joints_offset_output
            )
            if not self.args.use_arti_obj:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
            else:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
            noise_rel_base_pts_to_rhand_joints = noise_joints_offset_output.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
            
            # joints_offset_output_exp = joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(1), 1)
            # noise_rel_base_pts_to_rhand_joints = torch.randn_like(joints_offset_output_exp)
            # noise_rel_base_pts_to_rhand_joints = noise_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :].repeat(1, 1, 1, normed_base_pts.size(1), 1)
            # pert_joints_offset_output_exp = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
            #     joints_offset_output_exp, t, noise_rel_base_pts_to_rhand_joints
            # )
            # pert_rel_base_pts_to_rhand_joints = pert_joints_offset_output_exp - normed_base_pts.unsqueeze(1).unsqueeze(1) 
        
        
        if self.args.train_enc:
            pert_rel_base_pts_to_rhand_joints = rel_base_pts_to_rhand_joints
        
        # # bsz x ws x nnj x nnb x 3 
        # maxx_pert_basejtsrel, _ = torch.max(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # minn_pert_basejtsrel, _ = torch.min(pert_rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # maxx_basejtsrel, _ = torch.max(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # minn_basejtsrel, _ = torch.min(rel_base_pts_to_rhand_joints.view(-1, 3), dim=0)
        # # print(f"maxx_pert_basejtsrel: {maxx_pert_basejtsrel}, minn_pert_basejtsrel: {minn_pert_basejtsrel}, maxx_basejtsrel: {maxx_basejtsrel}, minn_basejtsrel: {minn_basejtsrel}")
        
        ### Add noise to avg joints sequence #
        noise_avg_joints_sequence = th.randn_like(avg_joints_sequence)
        pert_avg_joints_sequence = self.q_sample( # bsz x nnjts x 3 --> the avg joitns for each ... 
            avg_joints_sequence, t, noise_avg_joints_sequence
        )
        
        ### perturbe offset-joints ###
        # joints_offset_sequence
        noise_joints_offset_sequence = th.randn_like(joints_offset_sequence)
        pert_joints_offset_sequence = self.q_sample(
            joints_offset_sequence, t, noise_joints_offset_sequence
        )
        ### perturbe offset-joints ###
        
        # rel bae pts to
        if not self.args.use_arti_obj:
            pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
        else:
            pert_rel_base_pts_to_joints_for_jts_pred = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)

        if self.args.use_jts_pert_realbasejtsrel:
            
            pert_rel_base_pts_to_rhand_joints = pert_rel_base_pts_to_joints_for_jts_pred
            noise_rel_base_pts_to_rhand_joints = noise_joints_offset_sequence.unsqueeze(-2).repeat(1, 1, 1, normed_base_pts.size(-2), 1).contiguous()
        
        if self.args.finetune_with_cond:
            if not self.args.use_arti_obj:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(1).unsqueeze(1) 
            else:
                pert_rel_base_pts_to_rhand_joints = pert_joints_offset_sequence.unsqueeze(-2) - normed_base_pts.unsqueeze(2)
        
        
        input_data = {
          'base_pts': base_pts.clone(), # base pts ###
          'base_normals': base_normals.clone(), # base normals ### 
          'rel_base_pts_to_rhand_joints': rel_base_pts_to_rhand_joints.clone(), 
          'rhand_joints': rhand_joints,
          'avg_joints_sequence': avg_joints_sequence, ## bsz x nnjoints x 3 here for the avg_joints ##
          'pert_avg_joints_sequence': pert_avg_joints_sequence,
          'pert_rel_base_pts_to_rhand_joints': pert_rel_base_pts_to_rhand_joints, ## pert realtive base pts to rhand joints ##
          'pert_joints_offset_sequence': pert_joints_offset_sequence,
          'normed_base_pts': normed_base_pts,
          'pert_rel_base_pts_to_joints_for_jts_pred': pert_rel_base_pts_to_joints_for_jts_pred,
        }
        # if 'sampled_base_pts_nearest_obj_pc' in x_start:
        #     input_data.update(ambient_xstart_dict)
        # dist_base_pts_to_pert_rhand_joints, rel_base_pts_to_pert_rhand_joints #
        # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals #
        # bsz x ws - 1 x nnj x nnb # # input_data 
        
        ##### E ######
        # input_data.update(
        #     {
        #         # e_disp_rel_to_base_along_normals, e_disp_rel_to_baes_vt_normals
        #         ### e_disp_rel_to_base_along_normals: 
        #         'e_disp_rel_to_base_along_normals': e_disp_rel_to_base_along_normals, 
        #         'e_disp_rel_to_baes_vt_normals': e_disp_rel_to_baes_vt_normals,
        #         'pert_e_disp_rel_to_base_along_normals': pert_e_disp_rel_to_base_along_normals,
        #         'pert_e_disp_rel_to_base_vt_normals': pert_e_disp_rel_to_base_vt_normals,
        #     }
        # )
        ##### E ######
        
        # input_data.update(
        #   {k: x_start[k].clone() for k in x_start if k not in input_data}
        # ) # gaussian diffusion ours ## 
        # rel_base_pts_to_rhand_joints in the input_data #
        if model_kwargs is None:
            model_kwargs = {}

        terms = {} # latents in the latent space # # sequence latents #
        
        
        # if self.args.train_diff:
        #     with torch.no_grad():
        #         out_dict = model(input_data, self._scale_timesteps(t).clone())
        # else:
        # clean_joint_seq_latents: seq x bs x d #
        # clean_joint_seq_latents = model(input_data, self._scale_timesteps(t).clone()) ## 
        ### the strategy of removing noise from corresponding quantities ###
        out_dict = model(input_data, self._scale_timesteps(t).clone())
        ### get model output dictionary ###
        
        KL_loss = 0.
        
        terms['rot_mse'] = 0.
        
        ### diff_jts ###
        # out dict of the #
        # reumse checkpoints  #dec_in_dict
        dec_in_dict = {}

          
        if self.diff_realbasejtsrel: # 
            real_dec_basejtsrel = out_dict['real_dec_basejtsrel'] # bsz x nf x nnj x nnb x 3
            # noise_rel_base_pts_to_rhand_joints, rel_base_pts_to_rhand_joints
            if self.args.pred_diff_noise and not self.args.train_enc:
                # print(f"here predicting diff_noise...")
                if self.args.use_jts_pert_realbasejtsrel:
                    # print(f"use_jts_pert_realbasejtsrel!!!")
                    jts_pred_loss = torch.sum((
                        real_dec_basejtsrel[:, :, :, 0:1, :] - noise_rel_base_pts_to_rhand_joints[:, :, :, 0:1, :]
                    ) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1).mean(dim=-1)
                    # jts_pred_loss = torch.sum((
                    #     real_dec_basejtsrel  - noise_joints_offset_sequence
                    # ) ** 2, dim=-1 ).mean(dim=-1).mean(dim=-1)
                else:
                    jts_pred_loss = torch.sum((
                        real_dec_basejtsrel - noise_rel_base_pts_to_rhand_joints
                    ) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            else:
                jts_pred_loss = torch.sum((
                    real_dec_basejtsrel - rel_base_pts_to_rhand_joints
                ) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
            terms['jts_pred_loss'] = jts_pred_loss
            terms['rot_mse'] += jts_pred_loss
            
            if self.args.train_enc:
                obj_base_pts_feats = out_dict['obj_base_pts_feats'].detach()
                noise_obj_base_pts_feats = th.randn_like(obj_base_pts_feats) ## bsz x ws x nnjts x 3 -> for each joint point #
                pert_obj_base_pts_feats = self.q_sample( ## bsz x ws x nnjts x 3 --> pert in the spatial space ###
                    obj_base_pts_feats.permute(1, 0, 2), t, noise_obj_base_pts_feats.permute(1, 0, 2)
                ).permute(1, 0, 2)
                # seq_len x bsz x nn_feats_dim #
                dec_obj_base_pts_feats = model.model.denoising_realbasejtsrel_objbasefeats( pert_obj_base_pts_feats,  self._scale_timesteps(t).clone())
                if self.args.pred_diff_noise:
                    obj_base_pts_feats_denoising_loss = torch.sum(
                        (dec_obj_base_pts_feats - noise_obj_base_pts_feats) ** 2, dim=-1
                    ) / noise_obj_base_pts_feats.size(-1)
                    obj_base_pts_feats_denoising_loss = obj_base_pts_feats_denoising_loss.transpose(0, 1).mean(dim=-1)
                else:
                    obj_base_pts_feats_denoising_loss = torch.sum(
                        (dec_obj_base_pts_feats - obj_base_pts_feats) ** 2, dim=-1
                    ) / obj_base_pts_feats.size(-1)
                    obj_base_pts_feats_denoising_loss = obj_base_pts_feats_denoising_loss.transpose(0, 1).mean(dim=-1)
                
                terms['jts_latent_denoising_loss'] = obj_base_pts_feats_denoising_loss
                terms['rot_mse'] += obj_base_pts_feats_denoising_loss
            
        if self.diff_basejtsrel:
            
            if 'basejtsrel_output' in out_dict:
            
                basejtsrel_output = out_dict['basejtsrel_output'].transpose(-2, -3).contiguous()
                avg_jts_outputs = out_dict['avg_jts_outputs']
            
                # print(f"basejtsrel_output: {basejtsrel_output.size()}, noise_rel_base_pts_to_rhand_joints: {noise_rel_base_pts_to_rhand_joints.size()}, rel_base_pts_to_rhand_joints: {rel_base_pts_to_rhand_joints.size()}")
                
                if self.args.pred_diff_noise:
                    basejtsrel_denoising_loss = torch.sum((basejtsrel_output - noise_rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
                    avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
                else:
                    basejtsrel_denoising_loss = torch.sum((basejtsrel_output - rel_base_pts_to_rhand_joints) ** 2, dim=-1).mean(dim=-1).mean(dim=-1).mean(dim=-1)
                    avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
            else:
                joints_offset_output = out_dict['joints_offset_output']
                if self.args.pred_diff_noise:
                    basejtsrel_denoising_loss = torch.sum((joints_offset_output - noise_joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1) # bsz x ws x nnjts x 3 --> mean and mena over dim=-1
                else: # # basejtsrel denoising losses ##
                    basejtsrel_denoising_loss = torch.sum((joints_offset_output - joints_offset_sequence) ** 2, dim=-1).mean(dim=-1).mean(dim=-1) # 
                if 'avg_jts_outputs' in out_dict: # avg jts outputs ##
                    avg_jts_outputs = out_dict['avg_jts_outputs']
                    if self.args.pred_diff_noise:
                        avgjts_denoising_loss = torch.sum((avg_jts_outputs - noise_avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
                    else:
                        avgjts_denoising_loss = torch.sum((avg_jts_outputs - avg_joints_sequence) ** 2, dim=-1).mean(dim=-1)
                else:
                    avgjts_denoising_loss = torch.zeros_like(basejtsrel_denoising_loss)
                
            terms['basejtrel_denoising_loss'] = basejtsrel_denoising_loss
            terms['avgjts_denoising_loss'] = avgjts_denoising_loss  # # 
            terms['rot_mse'] += basejtsrel_denoising_loss + avgjts_denoising_loss # jts denoising ##
     
        # sv_out_in = {
        #     'model_output': model_output.detach().cpu().numpy(),
        #     'target': target.detach().cpu().numpy(),
        #     't': t.detach().cpu().numpy(),
        # }
        import os
        import datetime
        cur_time_stamp = datetime.datetime.now().timestamp()
        cur_time_stamp = str(cur_time_stamp)
        # sv_out_fn = os.path.join(sv_dir_rt, f"out_{cur_time_stamp}.npy")
        # np.save(sv_out_fn,sv_inter_dict )
        # print(f"Samples saved to {sv_out_fn}")
        target_xyz, model_output_xyz = None, None


        terms["loss"] = terms["rot_mse"]

        return terms



    def _prior_bpd(self, x_start):
        """
        Compute the prior KL term of the variational lower bound, in bits-per-dim.

        The forward-process posterior q(x_T | x_0) at the final timestep
        (``self.num_timesteps - 1``) is compared against a standard normal
        (mean 0, log-variance 0). No model output enters this term, so it
        cannot be optimized; it depends only on the encoder / noise schedule.

        :param x_start: the [N x C x ...] tensor of inputs.
        :return: a batch of [N] KL values (in bits), one per batch element.
        """
        n_examples = x_start.shape[0]
        last_step = self.num_timesteps - 1
        # One copy of the final timestep index per example, on the input's device.
        final_t = th.tensor([last_step] * n_examples, device=x_start.device)

        prior_mean, _, prior_log_var = self.q_mean_variance(x_start, final_t)
        kl_terms = normal_kl(
            mean1=prior_mean,
            logvar1=prior_log_var,
            mean2=0.0,
            logvar2=0.0,
        )
        # Reduce to one value per example, then divide by ln(2) to express in bits.
        per_example_kl = mean_flat(kl_terms)
        return per_example_kl / np.log(2.0)




def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Gather per-timestep values from a 1-D numpy array and broadcast them.

    :param arr: the 1-D numpy array to gather from (indexed by timestep).
    :param timesteps: a tensor of indices into ``arr``; the result is placed
                      on the same device as this tensor.
    :param broadcast_shape: the target shape of K dimensions whose leading
                            (batch) dimension equals the length of timesteps.
    :return: a float32 tensor of shape ``broadcast_shape``, obtained by
             appending singleton axes to the gathered values and expanding
             them (a view, not a copy).
    """
    lookup = th.from_numpy(arr).to(device=timesteps.device)
    picked = lookup[timesteps].float()
    # Append trailing singleton axes until the rank matches the target shape.
    n_missing = max(0, len(broadcast_shape) - len(picked.shape))
    picked = picked[(Ellipsis,) + (None,) * n_missing]
    return picked.expand(broadcast_shape)
