from __future__ import annotations

import sys
sys.path.insert(0, sys.path[0]+r"/../")

import os
import pdb
import random
import time
from typing import Literal
from dataclasses import dataclass, asdict, make_dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda import amp
import tyro
import yaml
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from tqdm import tqdm
import pickle
import json
import copy
import clip

from mld.train_mvae import Args as MVAEArgs
from mld.train_mvae import DataArgs, TrainArgs
from model.mld_denoiser import DenoiserMLP, DenoiserTransformer
from model.mld_vae import AutoMldVae
from model.text_encoder import CLIPTextEncoder, CLIPTextEncoderV2, CLIPTextEncoderV3, CLIPTextEncoderV4
from data_loaders.humanml.data.dataset import PrimitiveSequenceDataset, WeightedPrimitiveSequenceDataset, WeightedPrimitiveSequenceDatasetV2
from data_loaders.humanml.data.dataset_hml3d import HML3dDataset, HML3dDatasetV2, HML3dD262Dataset
from data_loaders.HHI.data.dataset_interhuman import InterHumanDataset, InterHumanDatasetV2, InterHumanMotionV2, InterHumanDatasetV3, InterHumanDatasetV4, InterGenDataset, InterGenDatasetWPE
from data_loaders.HHI.data.dataset_interx import InterXDataset, InterXDatasetWPE
from utils.smpl_utils import get_smplx_param_from_6d
from pytorch3d import transforms
from diffusion import gaussian_diffusion as gd
from diffusion.respace import SpacedDiffusion, space_timesteps
from diffusion.resample import create_named_schedule_sampler
from utils.intergen_util import cal_rel_rot as cal_rel_rot_from_joints
from utils.interx_util import cal_rel_rot_from_matrices
from utils.model_util import CosineWarmupScheduler


# Module-level debug flag.
debug = 0

# Chains of joint indices. All indices fall in 0..21; the first three chains
# start at joint 0 and the last two branch from joint 9. Presumably these are
# the limb/spine chains of a 22-joint body skeleton — TODO confirm.
kinematic_chain = [
    [0, 2, 5, 8, 11],
    [0, 1, 4, 7, 10],
    [0, 3, 6, 9, 12, 15],
    [9, 14, 17, 19, 21],
    [9, 13, 16, 18, 20],
]

@dataclass
class DiffusionArgs:
    """Settings consumed by ``create_gaussian_diffusion``."""

    # Number of diffusion timesteps.
    diffusion_steps: int = 10
    # Respacing spec passed to ``space_timesteps``; an empty string keeps every step.
    respacing: str = ""
    # Name of the beta schedule given to ``get_named_beta_schedule``.
    noise_schedule: Literal["linear", "cosine"] = "cosine"
    # True selects the FIXED_SMALL variance type, False selects FIXED_LARGE.
    sigma_small: bool = True
    # True makes the model predict x_0 (START_X); False makes it predict the noise (EPSILON).
    predict_xstart: bool = True

@dataclass
class DenoiserMLPArgs:
    """Constructor kwargs for ``DenoiserMLP`` (passed via ``asdict``)."""

    # Network size / regularization.
    h_dim: int = 512
    n_blocks: int = 2
    dropout: float = 0.1
    activation: str = 'gelu'

    # Condition-drop probability; ClassifierFreeWrapper requires the model's value to be > 0.
    cond_mask_prob: float = 0.1

    # Placeholders: Trainer overwrites clip_dim, history_shape and noise_shape
    # from the text encoder / dataset / VAE configuration.
    clip_dim: int = 512
    history_shape: tuple = (2, 276)
    noise_shape: tuple = (1, 128)

@dataclass
class DenoiserTransformerArgs:
    """Constructor kwargs for ``DenoiserTransformer`` and the ``transformerv*`` variants."""

    # Transformer size / regularization.
    h_dim: int = 512
    ff_size: int = 1024
    num_layers: int = 8
    num_heads: int = 4
    dropout: float = 0.1
    activation: str = 'gelu'

    # Condition-drop probability; ClassifierFreeWrapper requires the model's value to be > 0.
    cond_mask_prob: float = 0.1

    # Placeholders: Trainer overwrites clip_dim, history_shape and noise_shape
    # from the text encoder / dataset / VAE configuration.
    clip_dim: int = 512
    history_shape: tuple = (2, 276)
    noise_shape: tuple = (1, 128)

@dataclass
class DenoiserInterGenArgs:
    """Constructor kwargs for the InterGen-style denoiser.

    Trainer falls back to this config for any ``model_type`` it does not
    match explicitly (including the default ``"intergen"``). The fields mirror
    ``DenoiserTransformerArgs`` one-for-one.
    """

    # Transformer size / regularization.
    h_dim: int = 512
    ff_size: int = 1024
    num_layers: int = 8
    num_heads: int = 4
    dropout: float = 0.1
    activation: str = 'gelu'

    # Condition-drop probability; ClassifierFreeWrapper requires the model's value to be > 0.
    cond_mask_prob: float = 0.1

    # Placeholders: Trainer overwrites clip_dim, history_shape and noise_shape
    # from the text encoder / dataset / VAE configuration.
    clip_dim: int = 512
    history_shape: tuple = (2, 276)
    noise_shape: tuple = (1, 128)

@dataclass
class DenoiserInteractionArgs:
    """Constructor kwargs for ``DenoiserInteraction`` (used when ``use_interaction_model`` is set)."""

    # Transformer size / regularization.
    h_dim: int = 512
    ff_size: int = 1024
    num_layers: int = 8
    num_heads: int = 4
    dropout: float = 0.1
    activation: str = 'gelu'

    # Condition-drop probability.
    cond_mask_prob: float = 0.1

    # Placeholder: Trainer overwrites this with (future_length, 31).
    noise_shape: tuple = (8, 31)

@dataclass
class DenoiserArgs:
    """Configuration of the latent denoiser and its conditioning options.

    Most flags are forwarded as constructor kwargs of the denoiser class by
    ``Trainer`` and ``load_mld``.
    """
    mvae_path: str = './mvae/mvae_interhuman_sep_latent256/checkpoint_300000.pt'
    rescale_latent: int = 1

    train_rollout_type: Literal["single", "full"] = "single"
    train_rollout_history: str = "gt"  # "rollout" or "gt"

    use_indi_text: bool = True  # whether to use individual text embedding for each person

    text_ca: bool = field(default=False, metadata={"help": "Enable text cross attention."})
    text_sep: bool = field(default=False, metadata={"help": "Enable text separation."})
    sep_mode: int = 0  # 0 or 1

    his_mask_prob: float = -1
    interaction_mask_prob: float = 0.1
    shared_mask: bool = True

    use_controlnet: bool = False
    text_first: bool = True   # text first or not, only for text_ca

    use_inter: bool = False
    inter_first: bool = True  # interaction first or not, only for inter_ca

    normalize_relpose: bool = False  # whether to normalize the relative pose in the interaction model

    merge_his_relpose: bool = False  # whether to merge the history with the current relative pose

    attention_sep: bool = False  # whether to use separate attention for history and interaction text

    use_pre_latent: bool = False  # whether to use previous latent for the denoiser
    pre_max_len: int = 37  # max length of the previous latent, only used when use_pre_latent is True

    ffn_last: bool = False  # whether to use ffn last in the transformer
    init_zero: bool = True  # whether to initialize the linear layers to zero

    model_type: str = "intergen"
    # Nested configs use default_factory so every DenoiserArgs owns its own
    # instance. A dataclass-instance default would be shared by all objects,
    # and dataclasses on Python 3.11+ reject it (unhashable default).
    model_args: DenoiserMLPArgs | DenoiserTransformerArgs | DenoiserInterGenArgs = field(
        default_factory=DenoiserInterGenArgs
    )

    diffusion_args: DiffusionArgs = field(default_factory=DiffusionArgs)

    clip_version: str = 'ViT-B/32'  # version of CLIP model to use, default is ViT-B/32
    clip_device: str = 'cuda:0'  # device to use for CLIP model, default is cuda:0
    load_text_embedding: bool = True  # whether to load the text embedding from the dataset
    text_encoder_version: str = "v1"  # version of the text encoder to use, default is v1
    clip_final_proj: bool = False  # whether to use the final projection of the CLIP text encoder

    use_extra_pe: bool = False  # whether to use extra positional encoding for the denoiser
    use_step_pe: bool = True  # whether to use step positional encoding for the denoiser
    merge_his: bool = True  # whether to merge the history with the current latent in the denoiser
    merge_partner_his: bool = True  # whether to merge the partner's history with the current latent in the denoiser

@dataclass
class MLDArgs:
    """Top-level training configuration for the latent motion denoiser.

    ``Trainer`` writes it to ``args.yaml`` in the run's save directory and
    ``load_mld`` reads it back when restoring a checkpoint.
    """
    # Nested configs are built with default_factory so every MLDArgs gets its
    # own instances. With a shared instance default, in-place edits (e.g.
    # Trainer setting ``interaction_args.noise_shape``) would leak across
    # objects. Dataclasses on Python 3.11+ also reject unhashable instance
    # defaults outright.
    train_args: TrainArgs = field(default_factory=TrainArgs)
    data_args: DataArgs = field(default_factory=DataArgs)
    denoiser_args: DenoiserArgs = field(default_factory=DenoiserArgs)
    interaction_args: DenoiserInteractionArgs = field(default_factory=DenoiserInteractionArgs)

    padding: bool = False

    use_interaction_model: bool = False

    use_interaction_loss: bool = False  # use interaction loss or not
    scale_interloss_timestep: float = 1.0  # loss schedule

    finetune: bool = False
    # Checkpoint whose denoiser weights are loaded and then frozen (all but the
    # interaction-related modules) when denoiser_args.use_controlnet is set;
    # None disables this.
    freeze_single_ckpt: str | None = None

    exp_name: str = "mld_interhuman_intergen"
    seed: int = 0
    torch_deterministic: bool = True
    device: str = "cuda:0"
    save_dir: str = "./mld_denoiser"

    track: int = 0
    wandb_project_name: str = "mld_denoiser"
    wandb_entity: str = "Motion_Generation"

    warmup_steps: int = 0

    do_eval: bool = False
    eval_every_n_steps: int = 10000
    eval_save: bool = False

    react_prob: float = 0.2 # probability of reaction generation during training, -1 means no reaction generation


def create_gaussian_diffusion(args, enable_ddim=True):
    """Build a ``SpacedDiffusion`` from a ``DiffusionArgs``-like config.

    Args:
        args: object exposing ``diffusion_steps``, ``noise_schedule``,
            ``respacing``, ``sigma_small`` and ``predict_xstart``.
        enable_ddim: when False, ``args.respacing`` is ignored and every one
            of the ``diffusion_steps`` timesteps is kept.

    Returns:
        A ``SpacedDiffusion`` with an MSE loss, a fixed (never learned)
        variance, no beta scaling and no timestep rescaling.
    """
    num_steps = args.diffusion_steps
    # Beta scale of 1.0 means the named schedule is used unscaled.
    betas = gd.get_named_beta_schedule(args.noise_schedule, num_steps, 1.)

    # An empty respacing spec (or DDIM disabled) means "keep every timestep".
    respacing = args.respacing if enable_ddim else ''
    if not respacing:
        respacing = [num_steps]

    if args.predict_xstart:
        mean_type = gd.ModelMeanType.START_X
    else:
        mean_type = gd.ModelMeanType.EPSILON

    # The variance is never learned here, so LEARNED_RANGE is not an option.
    if args.sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=space_timesteps(num_steps, respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=gd.LossType.MSE,
        rescale_timesteps=False,
    )


class ClassifierFreeWrapper(nn.Module):
    """Classifier-free guidance wrapper around a conditional denoiser.

    ``forward`` runs the wrapped model twice, once with ``y['uncond'] = False``
    and once with ``y['uncond'] = True``. It blends the two outputs as
    ``out_uncond + y['scale'] * (out - out_uncond)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model  # model is the actual model to run

        # Guidance only makes sense if the model also learned the unconditional case.
        assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions'

    def forward(self, x, timesteps, y=None):
        """Return the classifier-free-guided prediction.

        Args:
            x: noisy input, passed through to the wrapped model.
            timesteps: diffusion timesteps, passed through to the wrapped model.
            y: condition dict. It must contain ``'scale'`` (the guidance weight).
                Its ``'uncond'`` entry is set to False. The unconditional pass
                uses a shallow copy, so the caller's dict is not left flagged
                as unconditional.

        Raises:
            TypeError: if ``y`` is None.
        """
        if y is None:
            raise TypeError("ClassifierFreeWrapper.forward requires a condition dict `y`")
        y['uncond'] = False
        out = self.model(x, timesteps, y)
        # Copy instead of aliasing: writing 'uncond' = True into ``y`` itself
        # would leak the unconditional flag back to the caller.
        y_uncond = copy.copy(y)
        y_uncond['uncond'] = True
        out_uncond = self.model(x, timesteps, y_uncond)
        return out_uncond + (y['scale'] * (out - out_uncond))


def load_mld(denoiser_checkpoint, device, pre_max_len=None):
    """Load a trained denoiser and its VAE for inference, both frozen.

    The ``MLDArgs`` saved during training are read from ``args.yaml`` in the
    directory containing ``denoiser_checkpoint``. The VAE is then loaded from
    ``denoiser_args.mvae_path`` via ``load_mvae``.

    Args:
        denoiser_checkpoint: path to the denoiser checkpoint file.
        device: device both models are moved to (also the ``map_location``).
        pre_max_len: if not None, passed to the denoiser constructor as
            ``pre_max_len``; otherwise that kwarg is not passed at all.

    Returns:
        ``(denoiser_args, denoiser_model, vae_args, vae_model, data_args)``,
        where ``denoiser_model`` is wrapped in ``ClassifierFreeWrapper`` and
        ``data_args.padding`` has been copied from the saved top-level args.

    Raises:
        ValueError: if ``denoiser_args.model_type`` is neither ``"mlp"`` nor
            ``"transformer"``.
    """
    denoiser_dir = Path(denoiser_checkpoint).parent
    with open(denoiser_dir / "args.yaml", "r") as f:
        args = tyro.extras.from_yaml(MLDArgs, yaml.safe_load(f))
    denoiser_args = args.denoiser_args
    data_args = args.data_args
    data_args.padding = args.padding
    print('denoiser model type:', denoiser_args.model_type)
    print('denoiser model args:', asdict(denoiser_args.model_args))

    if denoiser_args.model_type == "mlp":
        denoiser_class = DenoiserMLP
    elif denoiser_args.model_type == "transformer":
        denoiser_class = DenoiserTransformer
    else:
        # Fail loudly instead of hitting an unbound ``denoiser_class`` below.
        raise ValueError(
            f"Unsupported denoiser model_type {denoiser_args.model_type!r}; "
            "load_mld supports 'mlp' and 'transformer'"
        )

    model_kwargs = asdict(denoiser_args.model_args)
    for name in ('text_ca', 'text_sep', 'his_mask_prob', 'interaction_mask_prob', 'shared_mask'):
        model_kwargs[name] = getattr(denoiser_args, name)
    # Read with fallbacks, in case a saved config predates an option; each
    # fallback equals the corresponding DenoiserArgs default.
    optional_fields = (
        ('sep_mode', 0),
        ('use_inter', False),
        ('text_first', True),
        ('inter_first', True),
        ('use_controlnet', False),
        ('attention_sep', False),
        ('use_pre_latent', False),
        ('ffn_last', False),
        ('use_indi_text', True),
        ('use_extra_pe', False),
        ('use_step_pe', True),
        ('merge_his', True),
        ('merge_partner_his', True),
    )
    for name, fallback in optional_fields:
        model_kwargs[name] = getattr(denoiser_args, name, fallback)
    if pre_max_len is not None:
        model_kwargs['pre_max_len'] = pre_max_len
    denoiser_model = denoiser_class(**model_kwargs).to(device)

    checkpoint = torch.load(denoiser_checkpoint, map_location=device)
    print(f"Loading denoiser checkpoint from {denoiser_checkpoint}")
    missing, unexpected = denoiser_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    # strict=False tolerates key mismatches; report them rather than hide them.
    print(f"  >> Missing keys: {missing}")
    print(f"  >> Unexpected keys: {unexpected}")
    denoiser_model.requires_grad_(False)
    denoiser_model.eval()
    denoiser_model = ClassifierFreeWrapper(denoiser_model)

    # The VAE loading is identical to load_mvae, so reuse it.
    vae_args, vae_model = load_mvae(denoiser_args.mvae_path, device)
    return denoiser_args, denoiser_model, vae_args, vae_model, data_args

def load_mvae(mvae_ckpt, device):
    """Load a frozen ``AutoMldVae`` together with its training args.

    The VAE's ``args.yaml`` is read from the directory containing
    ``mvae_ckpt``. Checkpoints lacking ``latent_mean`` / ``latent_std`` get
    0 and 1 respectively, i.e. no latent normalization.

    Args:
        mvae_ckpt: path to the VAE checkpoint file.
        device: device the model is moved to (also the ``map_location``).

    Returns:
        ``(mvae_args, vae_model)``, with the model in eval mode and all of its
        parameters set to ``requires_grad=False``.
    """
    ckpt_dir = Path(mvae_ckpt).parent
    with open(ckpt_dir / "args.yaml", "r") as yaml_file:
        mvae_args = tyro.extras.from_yaml(MVAEArgs, yaml.safe_load(yaml_file))

    vae_kwargs = asdict(mvae_args.model_args)
    print('vae model args:', vae_kwargs)
    vae_model = AutoMldVae(**vae_kwargs).to(device)

    state = torch.load(mvae_ckpt, map_location=device)['model_state_dict']
    state.setdefault('latent_mean', torch.tensor(0))
    state.setdefault('latent_std', torch.tensor(1))
    vae_model.load_state_dict(state)
    # Reassigned explicitly: the original author observed these registered
    # buffers were not restored by load_state_dict.
    vae_model.latent_mean = state['latent_mean']
    vae_model.latent_std = state['latent_std']
    print(f"Loading vae checkpoint from {mvae_ckpt}")
    print(f"latent_mean: {vae_model.latent_mean}")
    print(f"latent_std: {vae_model.latent_std}")

    vae_model.requires_grad_(False)
    vae_model.eval()
    return mvae_args, vae_model

class Trainer:
    def __init__(self, args):
        self.args = args
        args.save_dir = Path(args.save_dir) / args.exp_name
        args.save_dir.mkdir(parents=True, exist_ok=True)
        train_args = args.train_args
        data_args = args.data_args
        denoiser_args = args.denoiser_args
        
        self.use_extra_pe = denoiser_args.use_extra_pe
        
        self.min_lr = 2e-6
        self.fids = [7, 10, 8, 11]
        self.feet_thre = 0.001

        # TRY NOT TO MODIFY: seeding
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.set_default_dtype(torch.float32)
        torch.backends.cudnn.deterministic = args.torch_deterministic
        device = torch.device(args.device if torch.cuda.is_available() else "cpu")

        # load dataset
        if data_args.dataset == 'mp_seq_v2':
            dataset_class = WeightedPrimitiveSequenceDatasetV2
        elif data_args.dataset == 'hml3d':
            dataset_class = HML3dDatasetV2
        elif data_args.dataset == 'hml3d_d262':
            dataset_class = HML3dD262Dataset
        elif data_args.dataset == 'interhuman':
            dataset_class = InterHumanDatasetV4
        elif data_args.dataset == 'interhuman_single':
            dataset_class = InterHumanMotionV2
        elif data_args.dataset == 'interhuman_d262':
            dataset_class = InterGenDataset
        elif data_args.dataset == 'interhuman_d262_wpe':
            dataset_class = InterGenDatasetWPE
        elif data_args.dataset == 'interx':
            dataset_class = InterXDataset
        elif data_args.dataset == 'interx_wpe':
            dataset_class = InterXDatasetWPE
        else:
            dataset_class = WeightedPrimitiveSequenceDataset
        mode = 'merged' if self.args.data_args.interaction else 'sep'
        data_args.enforce_gender = None if data_args.enforce_gender.lower() == 'none' else data_args.enforce_gender
        self.normalize_relpose = denoiser_args.normalize_relpose
        train_dataset = dataset_class(dataset_path=data_args.data_dir,
                                      dataset_name=data_args.dataset,
                                      cfg_path=data_args.cfg_path, prob_static=data_args.prob_static,
                                      enforce_gender=data_args.enforce_gender,
                                      enforce_zero_beta=data_args.enforce_zero_beta,
                                      body_type=data_args.body_type,
                                      split='train', device=device,
                                      weight_scheme=data_args.weight_scheme,
                                      text_tolerance=data_args.text_tolerance,
                                      mode=mode,
                                      text_sep=denoiser_args.text_sep,
                                      sep_mode=denoiser_args.sep_mode,
                                      padding=self.args.padding,
                                      use_interaction_model=self.args.use_interaction_model,
                                      normalize_relpose=denoiser_args.normalize_relpose,
                                      clip_version=denoiser_args.clip_version,
                                      load_text_embedding=denoiser_args.load_text_embedding,
                                      use_indi_text=denoiser_args.use_indi_text,
                                      )
        val_dataset = train_dataset
        # if 'text' in data_args.weight_scheme or 'samp:1' in data_args.weight_scheme:
        #     val_dataset = train_dataset
        # else:
        #     val_dataset = dataset_class(dataset_path=data_args.data_dir, dataset_name=data_args.dataset,
        #                                                    cfg_path=data_args.cfg_path, prob_static=data_args.prob_static,
        #                                                    enforce_gender=data_args.enforce_gender,
        #                                                    enforce_zero_beta=data_args.enforce_zero_beta,
        #                                                    split='val', device=device,
        #                                                    weight_scheme=data_args.weight_scheme,
        #                                                    text_tolerance=data_args.text_tolerance,
        #                                                    )
        
        # get primitive configs
        data_args.history_length = train_dataset.history_length
        data_args.future_length = train_dataset.future_length
        data_args.num_primitive = train_dataset.num_primitive
        data_args.feature_dim = 0
        data_args.padding = self.args.padding
        for k in train_dataset.motion_repr:
            data_args.feature_dim += train_dataset.motion_repr[k]

        mvae_checkpoint_dir = Path(denoiser_args.mvae_path).parent
        arg_path = mvae_checkpoint_dir / "args.yaml"
        with open(arg_path, "r") as f:
            mvae_args = tyro.extras.from_yaml(MVAEArgs, yaml.safe_load(f))

        assert mvae_args.data_args.history_length == data_args.history_length
        assert mvae_args.data_args.future_length == data_args.future_length
        # assert mvae_args.data_args.feature_dim == data_args.feature_dim

        run_name = f"{args.exp_name}__seed{args.seed}__{int(time.time())}"
        if args.track:
            import wandb
            wandb.init(
                project=args.wandb_project_name,
                entity=args.wandb_entity,
                sync_tensorboard=True,
                config=vars(args),
                name=run_name,
                save_code=True,
                # settings=wandb.Settings(code_dir="./mld"),
            )
            def include_fn(path, root):
                rel_path = os.path.relpath(path, root)
                flag = (rel_path.startswith("mld/") and len(Path(rel_path).parents) <= 2) or rel_path.startswith("model/")
                return flag
            wandb.run.log_code(root=".",
                               include_fn=include_fn
                               )
        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

        # load mvae model and freeze
        print('vae model args:', asdict(mvae_args.model_args))
        if mvae_args.train_args.finetune:
            vae_model = AutoMldVaeWithAdapter(
                **asdict(mvae_args.model_args),
            ).to(device)
        else:
            if mvae_args.model_args.type == "v1":
                vae_model = AutoMldVae(
                    **asdict(mvae_args.model_args),
                ).to(device)
            elif mvae_args.model_args.type == "v2":
                vae_model = AutoMldVaeV2(
                    **asdict(mvae_args.model_args),
                ).to(device)
        
        checkpoint = torch.load(denoiser_args.mvae_path, map_location=device)
        model_state_dict = checkpoint['model_state_dict']
        if 'latent_mean' not in model_state_dict:
            model_state_dict['latent_mean'] = torch.tensor(0)
        if 'latent_std' not in model_state_dict:
            model_state_dict['latent_std'] = torch.tensor(1)
        vae_model.load_state_dict(model_state_dict)
        vae_model.latent_mean = model_state_dict['latent_mean']  # register buffer seems to be not loaded by load_state_dict
        vae_model.latent_std = model_state_dict['latent_std']
        print(f"Loading vae checkpoint from {denoiser_args.mvae_path}")
        print(f"latent_mean: {vae_model.latent_mean}")
        print(f"latent_std: {vae_model.latent_std}")
        for param in vae_model.parameters():
            param.requires_grad = False
        vae_model.eval()
        
        if not denoiser_args.load_text_embedding:
            print('Loading CLIP...')
            # self.clip_device = denoiser_args.clip_device
            self.clip_device = device
            if denoiser_args.text_encoder_version == "v1":
                self.text_encoder = CLIPTextEncoder(denoiser_args.clip_version, clip_device=self.clip_device)
            elif denoiser_args.text_encoder_version == "v2":
                self.text_encoder = CLIPTextEncoderV2(denoiser_args.clip_version, clip_final_proj=denoiser_args.clip_final_proj, clip_device=self.clip_device)
            elif denoiser_args.text_encoder_version == "v3":
                self.text_encoder = CLIPTextEncoderV3(denoiser_args.clip_version, clip_final_proj=denoiser_args.clip_final_proj, clip_device=self.clip_device)
            elif denoiser_args.text_encoder_version == "v4":
                self.text_encoder = CLIPTextEncoderV4(denoiser_args.clip_version, clip_final_proj=denoiser_args.clip_final_proj, clip_device=self.clip_device)
        else:
            self.text_encoder = None

        if denoiser_args.model_type == "mlp":
            denoiser_class = DenoiserMLP
            denoiser_model_args = DenoiserMLPArgs()
        elif denoiser_args.model_type == "transformer":
            denoiser_class = DenoiserTransformer
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv2":
            denoiser_class = DenoiserTransformerV2
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv3":
            denoiser_class = DenoiserTransformerV3
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv4":
            denoiser_class = DenoiserTransformerV4
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv5":
            denoiser_class = DenoiserTransformerV5
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv6":
            denoiser_class = DenoiserTransformerV6
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv7":
            denoiser_class = DenoiserTransformerV7
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv8":
            denoiser_class = DenoiserTransformerV8
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv9":
            denoiser_class = DenoiserTransformerV9
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv13":
            denoiser_class = DenoiserTransformerV13
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv14":
            denoiser_class = DenoiserTransformerV14
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv15":
            denoiser_class = DenoiserTransformerV15
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv16":
            denoiser_class = DenoiserTransformerV16
            denoiser_model_args = DenoiserTransformerArgs()
        elif denoiser_args.model_type == "transformerv17":
            denoiser_class = DenoiserTransformerV17
            denoiser_model_args = DenoiserTransformerArgs()
        else: # isinstance(denoiser_model_args, DenoiserInterGenArgs)
            denoiser_class = DenoiserInterGen
            denoiser_model_args = DenoiserInterGenArgs()
            
        denoiser_model_args.history_shape = (data_args.history_length, data_args.feature_dim)
        denoiser_model_args.noise_shape = mvae_args.model_args.latent_dim
        
        args.denoiser_args.model_args = denoiser_model_args
        if denoiser_args.load_text_embedding:
            args.denoiser_args.model_args.clip_dim = train_dataset.dim_embed_text
        else:
            args.denoiser_args.model_args.clip_dim = self.text_encoder.clip_ln_final.normalized_shape[0]
       
        model_kwargs = asdict(denoiser_model_args)       
        model_kwargs['text_ca'] = self.args.denoiser_args.text_ca
        model_kwargs['text_sep'] = self.args.denoiser_args.text_sep
        model_kwargs['his_mask_prob'] = self.args.denoiser_args.his_mask_prob
        model_kwargs['interaction_mask_prob'] = self.args.denoiser_args.interaction_mask_prob
        model_kwargs['shared_mask'] = self.args.denoiser_args.shared_mask
        model_kwargs['use_controlnet'] = self.args.denoiser_args.use_controlnet
        model_kwargs['text_first'] = self.args.denoiser_args.text_first if hasattr(self.args.denoiser_args, 'text_first') else True
        model_kwargs['use_interaction_model'] = self.args.use_interaction_model
        model_kwargs['attention_sep'] = self.args.denoiser_args.attention_sep
        model_kwargs['use_pre_latent'] = self.args.denoiser_args.use_pre_latent
        model_kwargs['pre_max_len'] = self.args.denoiser_args.pre_max_len
        model_kwargs['use_indi_text'] = self.args.denoiser_args.use_indi_text
        model_kwargs['ffn_last'] = self.args.denoiser_args.ffn_last
        model_kwargs['init_zero'] = self.args.denoiser_args.init_zero
        model_kwargs['use_extra_pe'] = self.args.denoiser_args.use_extra_pe
        model_kwargs['use_step_pe'] = self.args.denoiser_args.use_step_pe
        model_kwargs['merge_his'] = self.args.denoiser_args.merge_his
        model_kwargs['merge_partner_his'] = self.args.denoiser_args.merge_partner_his
        
        denoiser_model = denoiser_class(**model_kwargs).to(device)
        print('denoiser model type:', denoiser_args.model_type)
        print('denoiser model args:', model_kwargs)
        
        if self.args.use_interaction_model:
            self.args.interaction_args.noise_shape = (data_args.future_length, 31)
            interaction_model = DenoiserInteraction(**asdict(self.args.interaction_args)).to(device)
        
        start_step = 1
        if args.train_args.resume_checkpoint is not None:
            checkpoint = torch.load(args.train_args.resume_checkpoint, map_location=device)
            model_state_dict = checkpoint['model_state_dict']
            denoiser_model.load_state_dict(model_state_dict)
            if self.args.use_interaction_model:
                inter_model_state_dict = checkpoint['inter_model_state_dict']
                interaction_model.load_state_dict(inter_model_state_dict)
            if not denoiser_args.load_text_embedding:
                text_encoder_state_dict = checkpoint['text_encoder_state_dict']
                self.text_encoder.load_state_dict(text_encoder_state_dict)
            start_step = checkpoint['num_steps'] + 1
            print(f"Loading checkpoint from {args.train_args.resume_checkpoint} at step {start_step}")
        
        # freeze params except for the interaction related params
        if args.denoiser_args.use_controlnet and args.freeze_single_ckpt is not None:
            checkpoint_freeze = torch.load(args.freeze_single_ckpt, map_location=device)
            missing, unexpected = denoiser_model.load_state_dict(checkpoint_freeze['model_state_dict'], strict=False)
            print(f"Loading checkpoint from {args.freeze_single_ckpt}")
            print(f"  >> Missing keys: {missing}")
            print(f"  >> Unexpected keys: {unexpected}")
            # freeze all
            for param in denoiser_model.parameters():
                param.requires_grad = False
            # unfreeze interaction related params
            self.set_requires_grad(denoiser_model, 
                                   ['control_proj', 'his_proj', 'text_proj', 'query_proj',
                                    'embed_history_b', 'embed_rel_pose', 'embed_text_inter', 
                                    'inter_ca', 'his_ca', 'relpose_ca', 'inter_text_ca'], 
                                   requires_grad=True)
            for name, param in denoiser_model.named_parameters():
                if param.requires_grad:
                    print(name)
            
        if self.args.use_interaction_model:
            denoiser_params = list(filter(lambda p: p.requires_grad, denoiser_model.parameters()))
            interaction_params = list(filter(lambda p: p.requires_grad, interaction_model.parameters()))
            extra_clip_params = []

            if not self.args.denoiser_args.load_text_embedding:
                extra_clip_params += list(filter(lambda p: p.requires_grad, self.text_encoder.parameters()))

            optimizer = optim.AdamW(
                denoiser_params + interaction_params + extra_clip_params,
                lr=train_args.learning_rate
            )

        else:
            denoiser_params = list(filter(lambda p: p.requires_grad, denoiser_model.parameters()))
            extra_clip_params = []

            if not self.args.denoiser_args.load_text_embedding:
                extra_clip_params += list(filter(lambda p: p.requires_grad, self.text_encoder.parameters()))

            optimizer = optim.AdamW(
                denoiser_params + extra_clip_params,
                lr=train_args.learning_rate
            )

        
        if args.train_args.resume_checkpoint is not None:
            old_opt_state = checkpoint['optimizer_state_dict']
            new_opt_state = optimizer.state_dict()
            new_opt_state['state'] = old_opt_state['state']
            optimizer.load_state_dict(new_opt_state)
        
        self.denoiser_model_avg = None
        if self.args.use_interaction_model:
            self.interaction_model_avg = None
        if train_args.ema_decay > 0:
            self.denoiser_model_avg = copy.deepcopy(denoiser_model)
            self.denoiser_model_avg.eval()
            if self.args.use_interaction_model:
                self.interaction_model_avg = copy.deepcopy(interaction_model)
                self.interaction_model_avg.eval()
            if not denoiser_args.load_text_embedding:
                self.text_encoder_avg = copy.deepcopy(self.text_encoder)
                self.text_encoder_avg.eval()

        with open(args.save_dir / "args.yaml", "w") as f:
            yaml.dump(tyro.extras.to_yaml(args), f)
        with open(args.save_dir / "args_read.yaml", "w") as f:
            yaml.dump(asdict(args), f)

        self.diffusion = create_gaussian_diffusion(args.denoiser_args.diffusion_args, enable_ddim=False)
        self.schedule_sampler_type = 'uniform'
        self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, self.diffusion)

        self.vae_model = vae_model
        self.vae_args = mvae_args
        self.data_args = data_args
        self.denoiser_model = denoiser_model
        if self.args.use_interaction_model:
            self.interaction_model = interaction_model
        self.optimizer = optimizer
        self.writer = writer
        self.start_step = start_step
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.device = device
        self.batch_size = train_args.batch_size
        self.step = start_step
        
        self.warmup_steps = args.warmup_steps
        if self.warmup_steps > 0:
            self.scheduler = CosineWarmupScheduler(
                optimizer,
                warmup=self.warmup_steps,
                max_iters=train_args.stage1_steps + train_args.stage2_steps + train_args.stage3_steps,
            )

        # self.rec_criterion = torch.nn.HuberLoss(reduction='mean', delta=1.0)
        self.rec_criterion = torch.nn.HuberLoss(reduction='none', delta=1.0)
        self.transf_rotmat = torch.eye(3, device=self.device).unsqueeze(0)
        self.transf_transl = torch.zeros(3, device=self.device).reshape(1, 1, 3)

    def set_requires_grad(self, model, keywords, requires_grad=True):
        """Toggle ``requires_grad`` on the parameters of ``model`` selected by name.

        A parameter is selected when any entry of ``keywords`` is a substring of its
        qualified name (as yielded by ``model.named_parameters()``). Parameters that
        match no keyword keep their current ``requires_grad`` flag.

        Args:
            model: Module whose parameters are scanned.
            keywords: Iterable of substrings to look for in parameter names.
            requires_grad: Value assigned to every selected parameter.
        """
        selected = (
            tensor
            for param_name, tensor in model.named_parameters()
            if any(keyword in param_name for keyword in keywords)
        )
        for tensor in selected:
            tensor.requires_grad = requires_grad

    def masked_loss(self, pred, target, valid_mask, mask=None):
        """Huber loss reduced to one value per sample, then averaged over valid samples.

        Args:
            pred: Prediction tensor; dim 0 is the batch.
            target: Target tensor. If its shape differs from ``pred`` (e.g. ``[1, B, D]``
                vs ``[B, 1, D]``), both are squeezed so that their layouts line up.
            valid_mask: Optional bool tensor indexing dim 0 of the per-sample loss;
                samples where it is False are excluded from the final mean.
            mask: Optional element-wise weight broadcastable against the loss. Each
                sample's loss becomes ``sum(loss * mask) / max(sum(mask), 1e-6)``.

        Returns:
            Scalar tensor; 0 when ``valid_mask`` selects no sample.
        """
        if pred.shape != target.shape:
            # Squeeze only to reconcile differing layouts. Squeezing equal shapes would
            # also drop a size-1 batch/time dim, which breaks indexing with `valid_mask`
            # (B == 1) and misaligns the loss with an unsqueezed `mask` (mixing samples).
            pred, target = pred.squeeze(), target.squeeze()
        loss = self.rec_criterion(pred, target)  # [B, ..., D]

        def per_sample(x, reduce):
            # Reduce every non-batch dim; a 1-D tensor is already per-sample.
            return reduce(x, dim=list(range(1, x.ndim))) if x.ndim > 1 else x

        if mask is not None:
            epsilon = 1e-6
            valid_counts = per_sample(mask, torch.sum).clamp(min=epsilon)
            loss = per_sample(loss * mask, torch.sum) / valid_counts
        else:
            loss = per_sample(loss, torch.mean)  # -> [B]

        if valid_mask is not None:
            loss = loss[valid_mask]
            if loss.numel() == 0:
                return torch.tensor(0.0, device=pred.device)
        return loss.mean()

    def foot_detect(self, feet_vel, feet_h, thres):
        """Binary foot-contact labels from foot velocity and height.

        A foot is in contact when its squared speed is below ``thres`` and its height
        is below a fixed per-foot limit (0.12, 0.05, 0.12, 0.05). The limits are
        4-element tensors, so the last dim of the inputs must broadcast against 4.

        Args:
            feet_vel: Tensor whose last dim holds the (x, y, z) velocity components.
            feet_h: Height tensor aligned with ``feet_vel[..., 0]``.
            thres: Squared-speed threshold.

        Returns:
            Float tensor of 0/1 contact flags.
        """
        dev = feet_vel.device
        speed_limit = torch.Tensor([thres] * 4).to(dev)
        height_limit = torch.Tensor([0.12, 0.05, 0.12, 0.05]).to(dev)

        vx, vy, vz = feet_vel[..., 0], feet_vel[..., 1], feet_vel[..., 2]
        speed_sq = vx ** 2 + vy ** 2 + vz ** 2

        slow_enough = speed_sq < speed_limit
        low_enough = feet_h < height_limit
        return (slow_enough & low_enough).float()
    
    def cal_geo_loss(self, pred_rot, gt_rot, valid_mask, eps=1e-7):
        """Mean geodesic angle (radians) between predicted and GT rotations.

        Both inputs hold 6D rotations (flattened to ``[-1, 6]`` internally); the
        first two dims of ``gt_rot`` are read as ``B, T``. The angle of the relative
        rotation ``R_gt @ R_pred^T`` is recovered from its trace; the cosine is clamped
        to ``[-1 + eps, 1 - eps]`` so ``acos`` stays finite.

        Args:
            pred_rot: Predicted 6D rotations.
            gt_rot: Ground-truth 6D rotations.
            valid_mask: Optional bool tensor over the batch; False samples are dropped.
            eps: Clamp margin for the cosine.

        Returns:
            Scalar tensor; 0 when ``valid_mask`` selects no sample.
        """
        batch, frames, *_ = gt_rot.shape
        mat_pred = transforms.rotation_6d_to_matrix(pred_rot.reshape(-1, 6))
        mat_gt = transforms.rotation_6d_to_matrix(gt_rot.reshape(-1, 6))
        relative = torch.bmm(mat_gt, mat_pred.transpose(1, 2))

        trace = relative[..., 0, 0] + relative[..., 1, 1] + relative[..., 2, 2]
        cos_angle = torch.clamp((trace - 1) / 2.0, -1.0 + eps, 1.0 - eps)
        angle = torch.acos(cos_angle).reshape(batch, frames, -1)

        per_sample = angle.mean(dim=list(range(1, angle.ndim)))
        if valid_mask is None:
            return per_sample.mean()
        kept = per_sample[valid_mask]
        if kept.numel() == 0:
            return torch.tensor(0.0, device=angle.device)
        return kept.mean()
    
    def calc_loss(self, motion, cond, history_motion, future_motion_gt, future_motion_pred, latent_gt, latent_pred, weights, timestep_mask, 
                  inter_start=None, inter_pred=None, noise_inter=None, noise_inter_pred=None,
                  noise=None, noise_pred=None):
        """Compute the per-person denoiser training losses for one primitive.

        Terms stored in the returned dict:

        * ``feature_rec`` / ``latent_rec``: reconstruction of the future features and
          of the latent (via ``self.masked_loss``).
        * ``noise_rec``: noise reconstruction, only when the diffusion does not predict
          x_start and both ``noise`` and ``noise_pred`` are given; otherwise 0.
        * ``inter_rec`` / ``noise_inter_rec``: the same for the interaction model,
          only when ``self.args.use_interaction_model`` is set.
        * ``vel_loss``, ``fc_loss``, ``bl_loss``, ``geo_loss``: joint velocity, foot
          contact, bone length and geodesic body-pose losses. They are computed only
          when ALL of ``weight_vel``, ``weight_fc``, ``weight_bl`` and ``weight_geo``
          are positive.
        * ``loss``: the weighted total.

        Args:
            motion, history_motion, weights: Unused here; kept for interface compatibility.
            cond: Condition dict; ``cond['y']`` may contain ``primitive_padding_mask``
                (its last column marks padded samples) and, for the physics losses,
                ``gender`` and ``betas``.
            future_motion_gt / future_motion_pred: Normalized future features; the first
                two dims are read as ``B, T``.
            latent_gt / latent_pred: Latent target and prediction.
            timestep_mask: Per-sample weight, used as ``timestep_mask.unsqueeze(1)...``
                (presumably shape ``[B]`` -- TODO confirm).
            inter_start, inter_pred, noise_inter, noise_inter_pred: Interaction-model
                targets/predictions (only read with ``use_interaction_model``).
            noise, noise_pred: Diffusion noise target and prediction.

        Returns:
            Dict of loss terms, or ``None`` when every sample in the batch is padding.
        """
        train_args = self.args.train_args
        predict_xstart = self.args.denoiser_args.diffusion_args.predict_xstart
        use_noise_rec = not predict_xstart and noise is not None and noise_pred is not None
        use_physics = (train_args.weight_vel > 0 and train_args.weight_fc > 0
                       and train_args.weight_bl > 0 and train_args.weight_geo > 0)

        primitive_padding_mask = cond['y'].get('primitive_padding_mask', None)
        if primitive_padding_mask is not None:
            valid_mask = ~primitive_padding_mask[:, -1].view(-1)
        else:
            valid_mask = None
        if valid_mask is not None and valid_mask.sum() == 0:
            return None

        terms = {}
        # feature and latent reconstruction
        feature_rec_loss = self.masked_loss(future_motion_pred, future_motion_gt, valid_mask)
        terms['feature_rec'] = feature_rec_loss
        latent_rec_loss = self.masked_loss(latent_pred, latent_gt, valid_mask)
        terms['latent_rec'] = latent_rec_loss

        # noise reconstruction (only meaningful when the model predicts noise)
        if use_noise_rec:
            terms['noise_rec'] = self.masked_loss(noise_pred, noise, valid_mask)
        else:
            terms['noise_rec'] = torch.tensor(0.0, device=self.device)

        if self.args.use_interaction_model:
            terms['inter_rec'] = self.masked_loss(inter_pred, inter_start, valid_mask)
            if not predict_xstart and noise_inter is not None and noise_inter_pred is not None:
                terms['noise_inter_rec'] = self.masked_loss(noise_inter_pred, noise_inter, valid_mask)
            else:
                terms['noise_inter_rec'] = torch.tensor(0.0, device=self.device)

        if use_physics:
            dataset = self.train_dataset
            utility = dataset.primitive_utility
            B, T, *_ = future_motion_gt.shape

            def flatten_batch_time(feature_dict):
                # Merge leading [B, T] dims into B*T for every array-like entry that has
                # them; entries without a >=2-d shape (e.g. gender) are left untouched.
                for key, value in list(feature_dict.items()):
                    shape = getattr(value, 'shape', None)
                    if shape is not None and len(shape) >= 2 and shape[0] == B and shape[1] == T:
                        feature_dict[key] = value.reshape(B * T, *shape[2:])

            if 'joints' in dataset.motion_repr:
                pred_joints = utility.tensor_to_dict(dataset.denormalize(future_motion_pred))['joints'].reshape(B, T, -1, 3)
                gt_joints = utility.tensor_to_dict(dataset.denormalize(future_motion_gt))['joints'].reshape(B, T, -1, 3)
            else:
                feature_dict_pred = utility.tensor_to_dict(dataset.denormalize(future_motion_pred))
                feature_dict_gt = utility.tensor_to_dict(dataset.denormalize(future_motion_gt))
                update_dict = {
                    'gender': cond['y']['gender'],
                    'betas': cond['y']['betas'][:, dataset.history_length:],
                }
                feature_dict_pred.update(update_dict)
                feature_dict_gt.update(update_dict)
                flatten_batch_time(feature_dict_pred)
                flatten_batch_time(feature_dict_gt)
                pred_joints = utility.calc_joints_from_features(feature_dict_pred).reshape(B, T, -1, 3)
                gt_joints = utility.calc_joints_from_features(feature_dict_gt).reshape(B, T, -1, 3)

            # velocity loss on frame-to-frame joint displacements
            pred_vel = pred_joints[:, 1:] - pred_joints[:, :-1]
            tgt_vel = gt_joints[:, 1:] - gt_joints[:, :-1]
            pred_vel = pred_vel.reshape(pred_vel.shape[:-2] + (-1,))
            tgt_vel = tgt_vel.reshape(tgt_vel.shape[:-2] + (-1,))
            terms['vel_loss'] = self.masked_loss(pred_vel, tgt_vel, valid_mask, timestep_mask.unsqueeze(1).unsqueeze(2).expand_as(pred_vel))

            # foot-contact loss: penalize predicted foot motion where GT feet are in contact
            fids = self.fids
            feet_vel = pred_joints[:, 1:, fids, :] - pred_joints[:, :-1, fids, :]     # [B, T-1, F, 3]
            gt_feet_vel = gt_joints[:, 1:, fids, :] - gt_joints[:, :-1, fids, :]      # [B, T-1, F, 3]
            gt_feet_h = gt_joints[:, :-1, fids, 1]                                    # [B, T-1, F]
            contact_gt = self.foot_detect(gt_feet_vel, gt_feet_h, self.feet_thre)     # [B, T-1, F]
            terms['fc_loss'] = self.masked_loss(feet_vel * contact_gt.unsqueeze(-1), torch.zeros_like(feet_vel), valid_mask, timestep_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand_as(feet_vel))

            # bone-length loss along each kinematic chain
            pred_bones, gt_bones = [], []
            for chain in kinematic_chain:
                for parent, child in zip(chain[:-1], chain[1:]):
                    pred_bones.append((pred_joints[..., parent, :] - pred_joints[..., child, :]).norm(dim=-1, keepdim=True))  # [B,T,1]
                    gt_bones.append((gt_joints[..., parent, :] - gt_joints[..., child, :]).norm(dim=-1, keepdim=True))
            pred_bones = torch.cat(pred_bones, dim=-1)
            gt_bones = torch.cat(gt_bones, dim=-1)
            terms['bl_loss'] = self.masked_loss(pred_bones, gt_bones, valid_mask, timestep_mask.unsqueeze(1).unsqueeze(2).expand_as(pred_bones))

            # geodesic loss on body-pose rotations
            pred_rot = utility.tensor_to_dict(dataset.denormalize(future_motion_pred))['body_pose'].reshape(B, T, -1, 6)
            gt_rot = utility.tensor_to_dict(dataset.denormalize(future_motion_gt))['body_pose'].reshape(B, T, -1, 6)
            # Without a padding mask only the timestep mask selects samples; the original
            # `valid_mask * timestep_mask` raised a TypeError when valid_mask was None.
            if valid_mask is None:
                geo_valid = timestep_mask.bool()
            else:
                geo_valid = (valid_mask * timestep_mask).bool()
            terms['geo_loss'] = self.cal_geo_loss(pred_rot, gt_rot, geo_valid)

        loss = train_args.weight_latent_rec * latent_rec_loss \
            + train_args.weight_feature_rec * feature_rec_loss
        if use_physics:
            loss = loss \
                + train_args.weight_vel * terms['vel_loss'] \
                + train_args.weight_fc * terms['fc_loss'] \
                + train_args.weight_bl * terms['bl_loss'] \
                + train_args.weight_geo * terms['geo_loss']
        if self.args.use_interaction_model:
            # weight_latent_rec doubles as the mix between x_start and noise reconstruction
            loss = loss + train_args.weight_inter_rec * (
                train_args.weight_latent_rec * terms['inter_rec']
                + (1.0 - train_args.weight_latent_rec) * terms['noise_inter_rec'])
        if use_noise_rec:
            loss = loss + (1.0 - train_args.weight_latent_rec) * terms['noise_rec']
        terms['loss'] = loss
        return terms

    def calc_interaction_loss(self, cond, future_motion_gt, future_motion_pred, timestep_mask, dm_thre=1.0, ja_thre=0.1, 
                  noise=None, noise_pred=None):
        """Compute the two-person interaction losses for one training primitive.

        GT and predicted future motions of both persons are denormalized, converted
        to SMPL dicts and transformed to the world frame. Three terms are then computed:

        * ``rel_orient``: loss between GT and predicted relative orientation of the
          two persons, compared in 6D rotation form.
        * ``joint_affinity``: pulls predicted inter-person joint distances towards
          zero wherever the GT distance is below ``ja_thre``.
        * ``distance_map``: matches predicted and GT inter-person distance matrices
          wherever the *predicted* distance is below ``dm_thre``.

        Args:
            cond: Per-person condition dicts (``cond['person1']``, ``cond['person2']``);
                each ``['y']`` supplies ``betas``, ``gender``, ``transf_rotmat``,
                ``transf_transl`` and optionally ``primitive_padding_mask``.
            future_motion_gt / future_motion_pred: Per-person normalized feature tensors;
                the first two dims are read as ``B, T``.
            timestep_mask: Per-sample weight indexed as ``timestep_mask[:, None, ...]``
                (presumably shape ``[B]`` -- TODO confirm).
            dm_thre: Distance threshold for the distance-map mask.
            ja_thre: Distance threshold for the joint-affinity mask.
            noise, noise_pred: Unused; kept for interface compatibility.

        Returns:
            Dict with the three terms plus their weighted sum under ``'loss'``, or
            ``None`` when every sample in the batch is padding.
        """
        train_args = self.args.train_args
        dataset = self.train_dataset
        utility = dataset.primitive_utility
        future_length = dataset.future_length
        persons = ('person1', 'person2')
        B, T, *_ = future_motion_gt['person1'].shape

        primitive_padding_mask = cond['person1']['y'].get('primitive_padding_mask', None)
        if primitive_padding_mask is not None:
            valid_mask = ~primitive_padding_mask[:, -1].view(-1)
        else:
            valid_mask = None
        if valid_mask is not None and valid_mask.sum() == 0:
            return None

        def to_world(person, motion):
            # Denormalized features -> SMPL dict -> world frame, for one person.
            y = cond[person]['y']
            motion_dict = dataset.tensor_to_dict(dataset.denormalize(motion))
            motion_dict.update({
                'transf_rotmat': y['transf_rotmat'],
                'transf_transl': y['transf_transl'],
                'gender': y['gender'],
                # Trailing future window, consistent with common_step. Assumes betas spans
                # history + future frames -- TODO confirm (values match if betas is constant).
                'betas': y['betas'][:, -future_length:, :],
                'pelvis_delta': utility.calc_calibrate_offset({
                    'betas': y['betas'][:, 0, :],
                    'gender': y['gender'],
                }),
            })
            primitive_dict = utility.feature_dict_to_smpl_dict(motion_dict)
            return utility.transform_primitive_to_world(primitive_dict)

        def to_joints(primitive_dict):
            # Flatten [B, T, ...] tensors to [B*T, ...] in a shallow copy (reshape makes new
            # tensors, so the source dict is not modified), then compute joint positions.
            flat = dict(primitive_dict)
            for key, value in list(flat.items()):
                if torch.is_tensor(value) and value.ndim >= 2 and value.shape[0] == B and value.shape[1] == T:
                    flat[key] = value.reshape(B * T, *value.shape[2:])
            return utility.calc_joints_from_features(flat, use_predicted_joints=True).reshape(B, T, -1, 3)

        with torch.no_grad():
            gt_world = {p: to_world(p, future_motion_gt[p]) for p in persons}
        pred_world = {p: to_world(p, future_motion_pred[p]) for p in persons}

        terms = {}
        # relative orientation loss
        if utility.feature_dim == 262:
            # 262-d representation: relative orientation is derived from joint positions
            gt_rel_global_orient, _ = cal_rel_rot_from_joints(gt_world['person1']['joints'].reshape(B, T, 22, 3),
                                                              gt_world['person2']['joints'].reshape(B, T, 22, 3))
            pred_rel_global_orient, _ = cal_rel_rot_from_joints(pred_world['person1']['joints'].reshape(B, T, 22, 3),
                                                                pred_world['person2']['joints'].reshape(B, T, 22, 3))
        else:
            gt_rel_global_orient, _ = cal_rel_rot_from_matrices(gt_world['person1']['global_orient'],
                                                                gt_world['person2']['global_orient'])
            pred_rel_global_orient, _ = cal_rel_rot_from_matrices(pred_world['person1']['global_orient'],
                                                                  pred_world['person2']['global_orient'])
        gt_rel_global_orient_6d = transforms.matrix_to_rotation_6d(gt_rel_global_orient)
        pred_rel_global_orient_6d = transforms.matrix_to_rotation_6d(pred_rel_global_orient)
        rel_orient_loss = self.masked_loss(pred_rel_global_orient_6d,
                                           gt_rel_global_orient_6d,
                                           valid_mask,
                                           timestep_mask.unsqueeze(1).unsqueeze(2).expand_as(gt_rel_global_orient_6d))
        terms['rel_orient'] = rel_orient_loss

        # inter-person joint distance matrices, [B, T, J, J]
        if 'joints' in gt_world['person1']:
            def pair_distances(world):
                return torch.cdist(world['person1']['joints'].reshape(B, T, 22, 3).contiguous(),
                                   world['person2']['joints'].reshape(B, T, 22, 3))
            gt_distance_matrix = pair_distances(gt_world)
            pred_distance_matrix = pair_distances(pred_world)
        else:
            # Joints are only needed here, so compute them lazily.
            with torch.no_grad():
                gt_joints = {p: to_joints(gt_world[p]) for p in persons}
            pred_joints = {p: to_joints(pred_world[p]) for p in persons}
            gt_distance_matrix = torch.cdist(gt_joints['person1'].contiguous(), gt_joints['person2'])
            pred_distance_matrix = torch.cdist(pred_joints['person1'].contiguous(), pred_joints['person2'])

        sample_weight = timestep_mask[:, None, None, None]
        # joint-affinity loss: pull predicted distances to zero where GT joints are close
        ja_matrix_mask = (gt_distance_matrix < ja_thre).float() * sample_weight
        joint_affinity_loss = self.masked_loss(pred_distance_matrix, torch.zeros_like(gt_distance_matrix), valid_mask, ja_matrix_mask)
        terms['joint_affinity'] = joint_affinity_loss
        # distance-map loss
        # NOTE(review): this mask is derived from the *prediction*, not the GT -- confirm intended.
        dm_matrix_mask = (pred_distance_matrix < dm_thre).float() * sample_weight
        distance_map_loss = self.masked_loss(pred_distance_matrix, gt_distance_matrix, valid_mask, dm_matrix_mask)
        terms['distance_map'] = distance_map_loss

        loss = train_args.weight_rel_orient * rel_orient_loss + \
               train_args.weight_joint_affinity * joint_affinity_loss + \
               train_args.weight_distance_map * distance_map_loss
        terms['loss'] = loss
        return terms

    def common_step(self, motion, cond, last_primitive, interaction=None, last_transf_rotmat=None, last_transf_transl=None, pre_latent=None, pre_transf_rotmat_abs=None, pre_transf_transl_abs=None, react_gen=False, use_rollout=True):
        denoiser_args = self.args.denoiser_args
        train_args = self.args.train_args
        future_length = self.train_dataset.future_length
        history_length = self.train_dataset.history_length
        num_primitive = self.train_dataset.num_primitive

        if self.args.data_args.interaction:
            future_motion_gt, history_motion_gt, history_motion, latent_gt, transf_rotmat, transf_transl = {}, {}, {}, {}, {}, {}
            person_list = ['person1', 'person2']
            if react_gen:
                actor = 'person1' if random.random() < 0.5 else 'person2'
            else:
                actor = ''
            for person in person_list:
                motion_tensor = motion[person].squeeze(2).permute(0, 2, 1)  # [B, T, D]
                future_motion_gt[person] = motion_tensor[:, -future_length:, :]
                history_motion_gt[person] = motion_tensor[:, :history_length, :]
                if last_primitive[person] is not None and (not react_gen or person != actor):
                    rollout_history, transf_rotmat[person], transf_transl[person] = self.get_rollout_history(last_primitive[person], cond[person], return_transform=True,
                                                                                                             transf_rotmat=last_transf_rotmat[person], transf_transl=last_transf_transl[person])
                    history_motion[person] = rollout_history  # [B, H, D]
                    cond[person]['y']['transf_rotmat'] = transf_rotmat[person]
                    cond[person]['y']['transf_transl'] = transf_transl[person]
                else:
                    history_motion[person] = history_motion_gt[person]
                
                latent_gt[person], _ = self.vae_model.encode(future_motion=future_motion_gt[person],
                                                    history_motion=history_motion_gt[person] if denoiser_args.train_rollout_history == "gt" else history_motion[person],
                                                    scale_latent=denoiser_args.rescale_latent)  # [T=1, B, D]
                # print('latent_gt:', latent_gt)
                # pdb.set_trace()
            
            if last_primitive[person] is not None:
                # compute relative pose
                rel_rotmat, rel_transl, rel_pose = {}, {}, {} 
                rel_rotmat['b2a'], rel_transl['b2a'] = self.train_dataset.primitive_utility.compute_rel_transform_B_in_A(
                    cond['person1']['y']['transf_rotmat'], cond['person1']['y']['transf_transl'], cond['person2']['y']['transf_rotmat'], cond['person2']['y']['transf_transl'])
                rel_rotmat['a2b'], rel_transl['a2b'] = self.train_dataset.primitive_utility.compute_rel_transform_B_in_A(
                    cond['person2']['y']['transf_rotmat'], cond['person2']['y']['transf_transl'], cond['person1']['y']['transf_rotmat'], cond['person1']['y']['transf_transl'])
                for rel in ['b2a', 'a2b']:
                    rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                    rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]
                interaction['rel_pose_b2a'] = rel_pose['b2a']
                interaction['rel_pose_a2b'] = rel_pose['a2b']
                if self.normalize_relpose:
                    interaction['rel_pose_b2a'] = self.train_dataset.normalize_rel_pose(interaction['rel_pose_b2a'])
                    interaction['rel_pose_a2b'] = self.train_dataset.normalize_rel_pose(interaction['rel_pose_a2b'])
                
                if self.args.use_interaction_model:
                    rel_global_orient, rel_root_transl, rel_mindis, rel_info = {}, {}, {}, {}
                    future_frames_dict_gt = {}
                    for person in person_list:
                        future_frames_gt = self.train_dataset.denormalize(future_motion_gt[person])
                        future_frames_dict_gt[person] = self.train_dataset.primitive_utility.tensor_to_dict(future_frames_gt)
                        future_frames_dict_gt[person].update(
                            {
                                'transf_rotmat': transf_rotmat[person],
                                'transf_transl': transf_transl[person],
                                'gender': cond[person]['y']['gender'],
                                'betas': cond[person]['y']['betas'][:,-future_length:, :],
                                'pelvis_delta': self.train_dataset.primitive_utility.calc_calibrate_offset({
                                                    'betas': cond[person]['y']['betas'][:, -1, :],  # [B, 10]
                                                    'gender': cond[person]['y']['gender'],
                                                }),
                            }
                        )
                        future_frames_dict_gt[person] = self.train_dataset.primitive_utility.feature_dict_to_smpl_dict(future_frames_dict_gt[person])
                        future_frames_dict_gt[person] = self.train_dataset.primitive_utility.transform_primitive_to_world(future_frames_dict_gt[person])
                    B, T, *_ = future_frames_dict_gt['person1']['body_pose'].shape
                    if self.train_dataset.primitive_utility.feature_dim in [276, 56*6, 55*12]:
                        rel_global_orient['b2a'] = future_frames_dict_gt['person1']['global_orient'].transpose(-1,-2) @ future_frames_dict_gt['person2']['global_orient']
                        rel_global_orient['a2b'] = future_frames_dict_gt['person2']['global_orient'].transpose(-1,-2) @ future_frames_dict_gt['person1']['global_orient']
                        # rel_root_transl['b2a'] = torch.matmul(future_frames_dict_gt['person1']['global_orient'].transpose(-1, -2), 
                        #                                         (future_frames_dict_gt['person2']['transl']-future_frames_dict_gt['person1']['transl']).unsqueeze(-1)).squeeze(-1)
                        # rel_root_transl['a2b'] = torch.matmul(future_frames_dict_gt['person2']['global_orient'].transpose(-1, -2), 
                        #                                         (future_frames_dict_gt['person1']['transl']-future_frames_dict_gt['person2']['transl']).unsqueeze(-1)).squeeze(-1)
                        rel_root_transl['b2a'] = torch.matmul((future_frames_dict_gt['person2']['transl']-future_frames_dict_gt['person1']['transl']), transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul((future_frames_dict_gt['person1']['transl']-future_frames_dict_gt['person2']['transl']), transf_rotmat['person2'])
                    elif self.train_dataset.primitive_utility.feature_dim == 262:
                        rel_global_orient['b2a'], rel_global_orient['a2b'] = cal_rel_rot_from_joints(future_frames_dict_gt['person1']['joints'].reshape(B, T, 22, 3), future_frames_dict_gt['person2']['joints'].reshape(B, T, 22, 3))
                        rel_root_transl['b2a'] = torch.matmul(future_frames_dict_gt['person2']['joints'].reshape(B, T, 22, 3)[:,:,0]-future_frames_dict_gt['person1']['joints'].reshape(B, T, 22, 3)[:,:,0], transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul(future_frames_dict_gt['person1']['joints'].reshape(B, T, 22, 3)[:,:,0]-future_frames_dict_gt['person2']['joints'].reshape(B, T, 22, 3)[:,:,0], transf_rotmat['person2']) 

                    dists = torch.norm(future_frames_dict_gt['person1']['joints'].unsqueeze(3)-future_frames_dict_gt['person2']['joints'].unsqueeze(2), dim=-1)
                    rel_mindis['b2a'], _ = dists.min(dim=-1)
                    rel_mindis['a2b'], _ = dists.min(dim=-2)
                    
                    rel_info = {}
                    for key in ['b2a', 'a2b']:
                        rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B*num_mp, T, 6+3+22]
                    interaction['rel_info_b2a'] = self.train_dataset.normalize_rel_info(rel_info['b2a'])
                    interaction['rel_info_a2b'] = self.train_dataset.normalize_rel_info(rel_info['a2b'])
            
            if self.args.denoiser_args.merge_his_relpose:
                history_motion_rel = {}
                for person in person_list:
                    if not react_gen or person == actor:
                        history_motion_rel[person] = copy.deepcopy(history_motion[person])
                        his_motion_rel_denormalized = self.train_dataset.denormalize(history_motion_rel[person])
                        his_motion_rel = self.train_dataset.primitive_utility.relative_transform_feature_tensor(
                            his_motion_rel_denormalized,
                            transforms.rotation_6d_to_matrix(interaction['rel_pose_a2b' if person == 'person1' else 'rel_pose_b2a'][:, :6]),
                            interaction['rel_pose_a2b' if person == 'person1' else 'rel_pose_b2a'][:, 6:9].unsqueeze(1),
                            cond[person]['y']['gender'],
                            cond[person]['y']['betas'][:, 0],
                        )
                        history_motion_rel[person] = self.train_dataset.normalize(his_motion_rel)  # [B, H, D]
            
            t, weights = self.schedule_sampler.sample(self.batch_size, device=self.device)  # weights always 1
            # t, weights = {}, {}
            loss_dict_sep, future_motion_pred, latent_pred = {}, {}, {}
            x_start = {}
            if self.args.use_interaction_model:
                latent_intra_block = {}
                noise_inter = {}
            noise_motion, noise_pred = {}, {}
            for person in person_list:
                if not react_gen or person != actor:
                    # forward diffusion
                    # t[person], weights[person] = self.schedule_sampler.sample(self.batch_size, device=self.device)  # weights always 1
                    x_start[person] = latent_gt[person].permute(1, 0, 2)  # [B, T=1, D]
                    noise_motion[person] = torch.randn_like(x_start[person])
                    x_t = self.diffusion.q_sample(x_start=x_start[person], t=t, noise=noise_motion[person])
                    # denoise
                    y = {
                        'history_motion_normalized': history_motion[person],
                        'history_motion_normalized_b': history_motion['person2' if person == 'person1' else 'person1'],
                        'text_inter': interaction['texts'],
                        'text_embedding_inter': interaction['text_embedding'],
                        'rel_pose': interaction['rel_pose_b2a' if person == 'person1' else 'rel_pose_a2b'],
                    }
                    if self.args.denoiser_args.use_indi_text:
                        y['text_embedding_indi'] = cond[person]['y']['text_embedding']
                    if self.args.denoiser_args.merge_his_relpose:
                        y['history_motion_normalized_b'] = history_motion_rel['person2' if person == 'person1' else 'person1']
                    if self.args.padding:
                        y['primitive_padding_mask'] = cond[person]['y']['primitive_padding_mask']
                    if self.args.denoiser_args.text_sep:
                        if self.args.denoiser_args.use_indi_text:
                            y['text_mask'] = cond[person]['y']['text_mask']
                        y['text_mask_inter'] = interaction['text_mask']
                    if self.use_extra_pe:
                        y['start_frame'] = cond[person]['y']['start_frame']
                        y['total_frames'] = cond[person]['y']['total_frames']
                    if self.args.denoiser_args.use_pre_latent:
                        y['pre_latent'] = torch.cat(pre_latent[person], dim=1) if len(pre_latent[person])!=0 else None
                        if len(pre_transf_rotmat_abs[person])==0:
                            y['pre_reltrans'] = None
                        else:
                            y['pre_reltrans'] = []
                            for transf_rotmat_abs, transf_transl_abs in zip(pre_transf_rotmat_abs[person], pre_transf_transl_abs[person]):
                                rel_rotmat, rel_transl = self.train_dataset.primitive_utility.compute_rel_transform_B_in_A(
                                    cond[person]['y']['transf_rotmat'], cond[person]['y']['transf_transl'], transf_rotmat_abs, transf_transl_abs)
                                y['pre_reltrans'].append(torch.cat([transforms.matrix_to_rotation_6d(rel_rotmat), rel_transl.squeeze(1)], dim=-1).unsqueeze(1))
                            y['pre_reltrans'] = torch.cat(y['pre_reltrans'], dim=1)  # [B, num_primitive, 6+3]

                    if self.args.denoiser_args.diffusion_args.predict_xstart:
                        if self.args.use_interaction_model:
                            x_start_pred, latent_intra_block[person] = self.denoiser_model(x_t=x_t, timesteps=self.diffusion._scale_timesteps(t), y=y)  # [B, T=1, D]
                        else:
                            x_start_pred = self.denoiser_model(x_t=x_t, timesteps=self.diffusion._scale_timesteps(t), y=y)  # [B, T=1, D]
                        noise_pred[person] = self.diffusion._predict_eps_from_xstart(x_t, t, x_start_pred)
                        latent_pred[person] = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
                        future_motion_pred[person] = self.vae_model.decode(latent_pred[person], history_motion[person], nfuture=future_length,
                                                                scale_latent=denoiser_args.rescale_latent)  # [B, F, D], normalized
                    else:
                        if self.args.use_interaction_model:
                            noise_pred[person], latent_intra_block[person] = self.denoiser_model(x_t=x_t, timesteps=self.diffusion._scale_timesteps(t), y=y)
                            x_start_pred = self.diffusion._predict_xstart_from_eps(x_t, t, noise_pred[person])
                        else:
                            noise_pred[person] = self.denoiser_model(x_t=x_t, timesteps=self.diffusion._scale_timesteps(t), y=y)
                            x_start_pred = self.diffusion._predict_xstart_from_eps(x_t, t, noise_pred[person])
                        latent_pred[person] = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
                        future_motion_pred[person] = self.vae_model.decode(latent_pred[person], history_motion[person], nfuture=future_length,
                                                                scale_latent=denoiser_args.rescale_latent)
                else:
                    future_motion_pred[person] = copy.deepcopy(future_motion_gt[person]).detach()
                    latent_pred[person] = copy.deepcopy(latent_gt[person]).detach()
                    loss_dict_sep[actor] = None  # no loss calculated for actor in react gen

            if react_gen:
                timestep_mask = (t[reactor] <= self.args.denoiser_args.diffusion_args.diffusion_steps*self.args.scale_interloss_timestep).float()
            else:
                timestep_mask = (t <= self.args.denoiser_args.diffusion_args.diffusion_steps*self.args.scale_interloss_timestep).float()
            
            for person in person_list:
                if not react_gen or person != actor:
                    if self.args.use_interaction_model:
                        inter_start = interaction['rel_info_b2a' if person == 'person1' else 'rel_info_a2b']
                        noise_inter = torch.randn_like(inter_start)
                        inter_t = self.diffusion.q_sample(x_start=inter_start, t=t, noise=noise_inter)
                        y = {
                            'latent_a': latent_intra_block[person],
                            'latent_b': latent_intra_block['person2' if person == 'person1' else 'person1'],
                        }
                        if self.args.padding:
                            y['primitive_padding_mask'] = cond[person]['y']['primitive_padding_mask'][:,-1]
                        if self.args.denoiser_args.diffusion_args.predict_xstart:
                            inter_pred = self.interaction_model(inter_t, t, y)
                            noise_pred_inter = self.diffusion._predict_eps_from_xstart(inter_t, t, inter_pred)
                        else:
                            noise_pred_inter = self.interaction_model(inter_t, t, y)
                            inter_pred = self.diffusion._predict_xstart_from_eps(inter_t, t, noise_pred_inter)
                    
                    loss_dict_sep[person] = self.calc_loss(motion[person], 
                                                        cond[person], 
                                                        history_motion[person], 
                                                        future_motion_gt[person], 
                                                        future_motion_pred[person], 
                                                        latent_gt[person], 
                                                        latent_pred[person], 
                                                        weights,
                                                        timestep_mask,
                                                        inter_start if self.args.use_interaction_model else None,
                                                        inter_pred if self.args.use_interaction_model else None,
                                                        noise_inter=noise_inter if self.args.use_interaction_model else None,
                                                        noise_inter_pred=noise_pred_inter if self.args.use_interaction_model else None,
                                                        noise=noise_motion[person] if not self.args.denoiser_args.diffusion_args.predict_xstart else None,
                                                        noise_pred=noise_pred[person] if not self.args.denoiser_args.diffusion_args.predict_xstart else None)
            
            if loss_dict_sep['person1'] is None and loss_dict_sep['person2'] is None:
                loss_dict = None
            else:
                loss_dict = {}
                if not react_gen:
                    for key in loss_dict_sep['person1'].keys():
                        loss_dict[key] = (loss_dict_sep['person1'][key] + loss_dict_sep['person2'][key])/2.0
                else:
                    reactor = 'person2' if actor == 'person1' else 'person1'
                    for key in loss_dict_sep[reactor].keys():
                        loss_dict[key] = loss_dict_sep[reactor][key]
            
                if self.args.use_interaction_loss:
                    loss_inter = self.calc_interaction_loss(cond, future_motion_gt, future_motion_pred, timestep_mask)
                    for key in loss_inter.keys():
                        if key in loss_dict:
                            loss_dict[key] += loss_inter[key]
                        else:
                            loss_dict[key] = loss_inter[key]
                            
            if self.step > train_args.stage1_steps and self.args.denoiser_args.train_rollout_type == "full" and use_rollout:  # sample with full ddpm loop to get rollout history
                # sample_fn = self.diffusion.p_sample_loop
                sample_fn = self.diffusion.p_sample_loop if self.args.denoiser_args.diffusion_args.respacing == '' else self.diffusion.ddim_sample_loop
                for person in person_list:
                    if not react_gen or person != actor:
                        with torch.no_grad():
                            with amp.autocast(enabled=bool(train_args.use_amp), dtype=torch.float16):
                                x_start_pred = sample_fn(
                                    self.denoiser_model,
                                    x_start[person].shape,
                                    clip_denoised=False,
                                    model_kwargs={'y': y},
                                    skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                                    init_image=x_start[person],
                                    progress=False,
                                    dump_steps=None,
                                    noise=None,
                                    const_noise=False,
                                )
                                latent_pred[person] = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
                                future_motion_pred[person] = self.vae_model.decode(latent_pred[person], history_motion[person], nfuture=future_length,
                                                                        scale_latent=denoiser_args.rescale_latent)  # [B, F, D], normalized
            
            if self.args.denoiser_args.use_pre_latent:
                if self.step > train_args.stage1_steps and self.args.denoiser_args.train_rollout_type == "full" and use_rollout:
                    return loss_dict, future_motion_pred, {person: latent_pred[person].permute(1, 0, 2).detach() for person in person_list}
                return loss_dict, future_motion_pred, {person: latent_gt[person].permute(1, 0, 2).detach() for person in person_list}
            else:
                return loss_dict, future_motion_pred
        else:
            motion_tensor = motion.squeeze(2).permute(0, 2, 1)  # [B, T, D]
            future_motion_gt = motion_tensor[:, -future_length:, :]
            history_motion_gt = motion_tensor[:, :history_length, :]
            if last_primitive is not None:
                rollout_history = self.get_rollout_history(last_primitive, cond)
                history_motion = rollout_history  # [B, H, D]
            else:
                history_motion = history_motion_gt
            latent_gt, _ = self.vae_model.encode(future_motion=future_motion_gt,
                                                history_motion=history_motion_gt if denoiser_args.train_rollout_history == "gt" else history_motion,
                                                scale_latent=denoiser_args.rescale_latent)  # [T=1, B, D]
            # print('latent_gt:', latent_gt)
            # pdb.set_trace()

            t, weights = self.schedule_sampler.sample(self.batch_size, device=self.device)  # weights always 1
            # print('t:', t, 'weights:', weights)

            # forward diffusion
            x_start = latent_gt.permute(1, 0, 2)  # [B, T=1, D]
            noise = torch.randn_like(x_start)
            x_t = self.diffusion.q_sample(x_start=x_start, t=t, noise=noise)
            # denoise
            y = {
                'text_embedding': cond['y']['text_embedding'],
                'history_motion_normalized': history_motion,
            }
            if self.args.padding:
                y['primitive_padding_mask'] = cond['y']['primitive_padding_mask']
            if self.args.denoiser_args.text_sep:
                y['text_mask'] = cond['y']['text_mask']
            if self.args.denoiser_args.diffusion_args.predict_xstart:
                x_start_pred = self.denoiser_model(x_t=x_t, timesteps=self.diffusion._scale_timesteps(t), y=y)  # [B, T=1, D]
                noise_pred = self.diffusion._predict_eps_from_xstart(x_t, t, x_start_pred)
            else:
                noise_pred = self.denoiser_model(x_t=x_t, timesteps=self.diffusion._scale_timesteps(t), y=y)
                x_start_pred = self.diffusion._predict_xstart_from_eps(x_t, t, noise_pred)
            latent_pred = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]

            future_motion_pred = self.vae_model.decode(latent_pred, history_motion, nfuture=future_length,
                                                    scale_latent=denoiser_args.rescale_latent)  # [B, F, D], normalized

            loss_dict = self.calc_loss(motion, cond, history_motion, future_motion_gt, future_motion_pred, latent_gt, latent_pred, weights, 
                                       noise=noise if not self.args.denoiser_args.diffusion_args.predict_xstart else None, 
                                       noise_pred=noise_pred if not self.args.denoiser_args.diffusion_args.predict_xstart else None)

            if self.step > train_args.stage1_steps and self.args.denoiser_args.train_rollout_type == "full":  # sample with full ddpm loop to get rollout history
                # sample_fn = self.diffusion.p_sample_loop
                sample_fn = self.diffusion.p_sample_loop if self.args.denoiser_args.diffusion_args.respacing == '' else self.diffusion.ddim_sample_loop
                with torch.no_grad():
                    with amp.autocast(enabled=bool(train_args.use_amp), dtype=torch.float16):
                        x_start_pred = sample_fn(
                            self.denoiser_model,
                            x_start.shape,
                            clip_denoised=False,
                            model_kwargs={'y': y},
                            skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                            init_image=x_start,
                            progress=False,
                            dump_steps=None,
                            noise=None,
                            const_noise=False,
                        )
                        latent_pred = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
                        # if torch.isnan(latent_pred).any() or torch.isinf(latent_pred).any():
                        #     print('latent_pred numerical error')
                        #     pdb.set_trace()
                        future_motion_pred = self.vae_model.decode(latent_pred, history_motion, nfuture=future_length,
                                                                scale_latent=denoiser_args.rescale_latent)  # [B, F, D], normalized

            return loss_dict, future_motion_pred

    def train(self):
        """Run the training loop until ``self.step`` exceeds the total step budget.

        The budget is ``stage1_steps + stage2_steps + stage3_steps``. Each outer
        iteration fetches one batch from ``self.train_dataset`` and consumes it
        primitive by primitive, taking one optimizer step per primitive. After
        ``stage1_steps``, the next primitive's history is taken from the model's own
        prediction (``use_rollout``) with a probability that ramps linearly from 0
        to 1 over ``stage2_steps``.

        Side effects: updates model/optimizer state, the EMA copies (when
        ``ema_decay > 0``), ``self.step``, TensorBoard scalars, checkpoints,
        validation, and, if ``do_eval``, periodic evaluation plus best-model saving.
        """
        denoiser_model = self.denoiser_model
        if self.args.use_interaction_model:
            interaction_model = self.interaction_model
            interaction_model.train()
        optimizer = self.optimizer
        train_args = self.args.train_args
        writer = self.writer
        num_primitive = self.train_dataset.num_primitive
        is_interaction = self.args.data_args.interaction

        denoiser_model.train()
        total_steps = train_args.stage1_steps + train_args.stage2_steps + train_args.stage3_steps
        # The step counter advances num_primitive times per batch, so round the remaining
        # count up to whole batches. This equals the number of next(progress_bar) calls below.
        rest_steps = ((total_steps - self.start_step) // num_primitive + 1) * num_primitive
        progress_bar = iter(tqdm(range(rest_steps)))
        self.step = self.start_step

        # (live model, EMA copy) pairs refreshed after every optimizer step. Empty when EMA is off.
        ema_pairs = []
        if train_args.ema_decay > 0:
            ema_pairs.append((self.denoiser_model, self.denoiser_model_avg))
            if self.args.use_interaction_model:
                ema_pairs.append((self.interaction_model, self.interaction_model_avg))
            if not self.args.denoiser_args.load_text_embedding:
                ema_pairs.append((self.text_encoder, self.text_encoder_avg))

        if self.args.do_eval:
            from evaluation.eval_inter_react_mvae_v1 import evaluation_during_training, EvalArgs
            eval_file = os.path.join(self.args.save_dir, 'evaluation_training.log')
            eval_args = EvalArgs()
            eval_args.load_dir = self.args.save_dir
            eval_args.dataset = self.args.data_args.dataset
            min_fid = np.inf
            max_top1 = -np.inf

        while self.step <= total_steps:
            # Learning-rate schedule: warmup scheduler first, then optional linear annealing.
            if self.step < self.warmup_steps:
                self.scheduler.step()
            elif train_args.anneal_lr:
                frac = 1.0 - (self.step - self.warmup_steps - 1.0) / (total_steps - self.warmup_steps)
                optimizer.param_groups[0]["lr"] = max(frac * train_args.learning_rate, self.min_lr)

            # Decide once per batch whether to train in reaction-generation mode.
            # `react_prob and interaction` evaluates to `interaction` (True) for any non-zero
            # react_prob, which made react_gen always on. Compare against the probability itself.
            # random.random() is still drawn every batch so the RNG stream is unchanged.
            react_threshold = self.args.react_prob if is_interaction else 0.0
            react_gen = random.random() < react_threshold

            with amp.autocast(enabled=bool(train_args.use_amp), dtype=torch.float16):
                batch = self.train_dataset.get_batch(self.batch_size)

            last_primitive = None if not is_interaction else {'person1': None, 'person2': None}
            last_transf_rotmat = None if not is_interaction else {'person1': None, 'person2': None}
            last_transf_transl = None if not is_interaction else {'person1': None, 'person2': None}
            if self.args.denoiser_args.use_pre_latent:
                pre_latent = [] if not is_interaction else {'person1': [], 'person2': []}
                pre_transf_rotmat_abs = [] if not is_interaction else {'person1': [], 'person2': []}
                pre_transf_transl_abs = [] if not is_interaction else {'person1': [], 'person2': []}
            for primitive_idx in range(num_primitive):
                # After stage 1, roll out the model's own prediction as history with a
                # probability that ramps linearly over stage2_steps.
                use_rollout = False
                if self.step > train_args.stage1_steps:
                    rollout_prob = min(1.0, (self.step - train_args.stage1_steps) / max(
                        float(train_args.stage2_steps), 1e-6))
                    if torch.rand(1).item() < rollout_prob:
                        use_rollout = True
                with amp.autocast(enabled=bool(train_args.use_amp), dtype=torch.float16):
                    if is_interaction:
                        person_list = ['person1', 'person2']
                        motion, cond = {}, {}
                        for person in person_list:
                            motion[person], cond[person] = self.get_primitive_batch(batch[person], primitive_idx)
                        interaction = batch['interaction'][primitive_idx]
                        # Encode text on the fly unless precomputed embeddings are loaded.
                        if not self.args.denoiser_args.load_text_embedding:
                            if self.args.denoiser_args.text_encoder_version == 'v1':
                                interaction['text_embedding'] = self.text_encoder(interaction['texts'])
                                if self.args.denoiser_args.use_indi_text:
                                    for person in person_list:
                                        cond[person]['y']['text_embedding'] = self.text_encoder(cond[person]['y']['text'])
                            elif self.args.denoiser_args.text_encoder_version in ['v2', 'v3', 'v4']:
                                interaction['text_embedding'], interaction['text_mask'] = self.text_encoder(interaction['texts'])
                                if self.args.denoiser_args.use_indi_text:
                                    for person in person_list:
                                        cond[person]['y']['text_embedding'], cond[person]['y']['text_mask'] = self.text_encoder(cond[person]['y']['text'])
                        if self.args.denoiser_args.use_pre_latent:
                            loss_dict, future_motion_pred, last_latent = \
                                self.common_step(motion, cond, last_primitive, interaction,
                                                 last_transf_rotmat=last_transf_rotmat,
                                                 last_transf_transl=last_transf_transl,
                                                 pre_latent=pre_latent,
                                                 pre_transf_rotmat_abs=pre_transf_rotmat_abs,
                                                 pre_transf_transl_abs=pre_transf_transl_abs,
                                                 react_gen=react_gen,
                                                 use_rollout=use_rollout)
                            for person in person_list:
                                pre_latent[person].append(last_latent[person])
                                pre_transf_rotmat_abs[person].append(cond[person]['y']['transf_rotmat'])
                                pre_transf_transl_abs[person].append(cond[person]['y']['transf_transl'])
                        else:
                            loss_dict, future_motion_pred = self.common_step(motion, cond, last_primitive, interaction,
                                                                             last_transf_rotmat=last_transf_rotmat,
                                                                             last_transf_transl=last_transf_transl,
                                                                             react_gen=react_gen,
                                                                             use_rollout=use_rollout)
                        if loss_dict is None:
                            break
                        loss = loss_dict['loss']
                    else:
                        motion, cond = self.get_primitive_batch(batch, primitive_idx)
                        if not self.args.denoiser_args.load_text_embedding:
                            if self.args.denoiser_args.text_encoder_version == 'v1':
                                cond['y']['text_embedding'] = self.text_encoder(cond['y']['text'])
                            elif self.args.denoiser_args.text_encoder_version in ['v2', 'v3', 'v4']:
                                cond['y']['text_embedding'], cond['y']['text_mask'] = self.text_encoder(cond['y']['text'])
                        loss_dict, future_motion_pred = self.common_step(motion, cond, last_primitive)
                        if loss_dict is None:
                            break
                        loss = loss_dict['loss']

                # optimize
                # NOTE(review): fp16 autocast is used without a GradScaler, so small gradients
                # may underflow. Confirm this is intended.
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(denoiser_model.parameters(), train_args.grad_clip)
                if not self.args.denoiser_args.load_text_embedding:
                    nn.utils.clip_grad_norm_(self.text_encoder.parameters(), train_args.grad_clip)
                if self.args.use_interaction_model:
                    nn.utils.clip_grad_norm_(interaction_model.parameters(), train_args.grad_clip)
                optimizer.step()

                # Exponential moving average: avg <- decay * avg + (1 - decay) * param.
                for model, avg_model in ema_pairs:
                    for param, avg_param in zip(model.parameters(), avg_model.parameters()):
                        avg_param.data.mul_(train_args.ema_decay).add_(
                            param.data, alpha=1 - train_args.ema_decay)

                # History for the next primitive: ground truth unless this step rolled out.
                last_primitive = None if not is_interaction else {'person1': None, 'person2': None}
                last_transf_rotmat = None if not is_interaction else {'person1': None, 'person2': None}
                last_transf_transl = None if not is_interaction else {'person1': None, 'person2': None}

                if use_rollout:
                    if is_interaction:
                        for person in person_list:
                            last_primitive[person] = future_motion_pred[person].detach()
                            last_transf_rotmat[person] = cond[person]['y']['transf_rotmat'].detach()
                            last_transf_transl[person] = cond[person]['y']['transf_transl'].detach()
                    else:
                        last_primitive = future_motion_pred.detach()  # assume future length >= history length

                if self.step % train_args.log_interval == 0:
                    for key in loss_dict:
                        writer.add_scalar(f"loss/{key}", loss_dict[key].item(), self.step)
                    writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], self.step)

                if self.step % train_args.save_interval == 0 or self.step == total_steps:
                    self.save()

                if self.step % train_args.val_interval == 0 or self.step == total_steps:
                    self.validate()
                # Use self.args: the original relied on a module-global `args` existing.
                if self.args.do_eval and self.step > (train_args.stage1_steps + train_args.stage2_steps) and self.step % self.args.eval_every_n_steps == 0:
                    if self.args.eval_save:
                        self.save(self.args.save_dir / f"checkpoint_{self.step}.pt")
                    else:
                        if self.denoiser_model_avg is None:
                            original_mode = self.denoiser_model.training
                            self.denoiser_model.eval()
                            if not self.args.denoiser_args.load_text_embedding:
                                original_mode_textencoder = self.text_encoder.training
                                self.text_encoder.eval()
                            fid, mat, top1 = evaluation_during_training(self.args, eval_args, self.data_args, self.denoiser_model,
                                                                        self.vae_model, self.vae_args, self.text_encoder,
                                                                        self.step, eval_file, device=self.device,)
                        else:
                            fid, mat, top1 = evaluation_during_training(self.args, eval_args, self.data_args, self.denoiser_model_avg,
                                                                        self.vae_model, self.vae_args, self.text_encoder_avg,
                                                                        self.step, eval_file, device=self.device,)
                        if fid < min_fid:
                            min_fid = fid
                            self.save(self.args.save_dir / f"best_fid.pt")
                            print('Best FID Model So Far!~')
                        if top1 > max_top1:
                            max_top1 = top1
                            self.save(self.args.save_dir / f"best_top1.pt")
                            print('Best Top1 Model So Far!~')
                        writer.add_scalar('Test/FID', fid, self.step)
                        writer.add_scalar('Test/Matching', mat, self.step)
                        writer.add_scalar('Test/Top1', top1, self.step)
                        # Restore the train/eval flags changed above.
                        if self.denoiser_model_avg is None:
                            self.denoiser_model.train(original_mode)
                            if not self.args.denoiser_args.load_text_embedding:
                                self.text_encoder.train(original_mode_textencoder)

                self.step += 1
                next(progress_bar)

    def get_primitive_batch(self, batch, primitive_idx):
        """Split the ``primitive_idx``-th entry of ``batch`` into ``(motion, cond)``.

        ``motion`` is the entry's ``'motion_tensor_normalized'`` value. ``cond`` has the
        single key ``'y'``, holding the always-present conditioning fields plus the
        optional ones (extra positional info, per-person text, padding mask) that the
        corresponding config flags enable.
        """
        sample = batch[primitive_idx]
        motion = sample['motion_tensor_normalized']  # [bs, D, 1, T]
        # Mandatory conditioning fields, copied under their original key names.
        y = {key: sample[key] for key in (
            'gender',
            'betas',           # [bs, T, 10]
            'history_motion',  # [bs, D, 1, T]
            'history_mask',
            'history_length',
            'future_length',
            'transf_rotmat',
            'transf_transl',
        )}
        if self.use_extra_pe:
            y.update(start_frame=sample['start_frame'], total_frames=sample['total_frames'])
        denoiser_args = self.args.denoiser_args
        if denoiser_args.use_indi_text:
            y['text'] = sample['texts']
            if denoiser_args.load_text_embedding:
                y['text_embedding'] = sample['text_embedding']
            if denoiser_args.text_sep:
                y['text_mask'] = sample['text_mask']
        if self.args.padding:
            y['primitive_padding_mask'] = sample['primitive_padding_mask']
        return motion, {'y': y}

    def get_rollout_history(self, last_primitive, cond,
                            return_transform=False,
                            transf_rotmat=None, transf_transl=None
                            ):
        """Turn the tail of a generated primitive into the normalized history seed for the next one.

        The last ``history_length`` frames of ``last_primitive`` are denormalized, then
        re-canonicalized per gender group with ``primitive_utility.get_blended_feature``,
        then normalized again.

        Args:
            last_primitive: normalized motion indexed as ``[B, T, D]`` (frames on dim 1).
            cond: conditioning dict. Reads ``cond['y']['gender']`` (one entry per batch
                item) and ``cond['y']['betas']`` (indexed as ``[B, T, ...]``).
            return_transform: also return the new canonical transforms.
            transf_rotmat, transf_transl: per-item current transforms. When None,
                ``self.transf_rotmat`` / ``self.transf_transl`` are repeated for every item.

        Returns:
            ``rollout_history`` (``[B, T, D]``), or the tuple
            ``(rollout_history, transf_rotmat, transf_transl)`` when ``return_transform``
            is True. Rows follow the input batch order. Items whose gender is not
            female/male/neutral are skipped.
        """
        history_length = self.train_dataset.history_length
        motion_tensor = last_primitive[:, -history_length:, :]  # [B, T, D]
        new_history_frames = self.train_dataset.denormalize(motion_tensor)
        primitive_utility = self.train_dataset.primitive_utility
        use_predicted_joints = getattr(self.args.train_args, 'use_predicted_joints', False)
        genders = cond['y']['gender']
        rollout_history, new_transf_rotmat, new_transf_transl = [], [], []
        # Batch indices in the order their rows are appended to the lists above.
        processed_idx = []
        for gender_name in ['female', 'male', 'neutral']:
            gender_idx = [idx for idx, gender in enumerate(genders) if gender == gender_name]
            if len(gender_idx) == 0:
                continue
            history_feature_dict = primitive_utility.tensor_to_dict(new_history_frames[gender_idx])
            history_feature_dict.update(
                {
                    'transf_rotmat': self.transf_rotmat.repeat(len(gender_idx), 1, 1) if transf_rotmat is None else transf_rotmat[gender_idx],
                    'transf_transl': self.transf_transl.repeat(len(gender_idx), 1, 1) if transf_transl is None else transf_transl[gender_idx],
                    'gender': gender_name,
                    'betas': cond['y']['betas'][gender_idx, -history_length:, :],
                }
            )
            history_feature_dict['pelvis_delta'] = primitive_utility.calc_calibrate_offset({
                'betas': history_feature_dict['betas'][:, 0, :],
                'gender': gender_name,
            })
            canonicalized_history_primitive_dict, blended_feature_dict = primitive_utility.get_blended_feature(
                history_feature_dict, use_predicted_joints=use_predicted_joints)
            new_transf_rotmat.append(canonicalized_history_primitive_dict['transf_rotmat'])
            new_transf_transl.append(canonicalized_history_primitive_dict['transf_transl'])
            rollout_history.append(primitive_utility.dict_to_tensor(blended_feature_dict))
            processed_idx.extend(gender_idx)

        # Items were processed gender group by gender group, so the concatenation is
        # grouped by gender. Permute back to batch order so row i still pairs with
        # cond[...][i] (betas, transforms, text) in mixed-gender batches.
        restore_order = sorted(range(len(processed_idx)), key=processed_idx.__getitem__)
        rollout_history = torch.cat(rollout_history, dim=0)[restore_order]
        rollout_history = self.train_dataset.normalize(rollout_history)  # [B, T, D]

        if return_transform:
            return (rollout_history,
                    torch.cat(new_transf_rotmat, dim=0)[restore_order],
                    torch.cat(new_transf_transl, dim=0)[restore_order])
        return rollout_history

    def save(self, checkpoint_path=None):
        """Write a training checkpoint with ``torch.save``.

        For each model, the averaged counterpart stored in ``<name>_avg`` is saved
        instead of the live model whenever that attribute is not None.

        Saved keys, in order: ``num_steps``, ``model_state_dict`` (denoiser),
        ``optimizer_state_dict``, then ``inter_model_state_dict`` when
        ``args.use_interaction_model`` is set, and ``text_encoder_state_dict``
        when ``args.denoiser_args.load_text_embedding`` is false.

        :param checkpoint_path: destination file; defaults to
            ``args.save_dir / f"checkpoint_{step}.pt"``.
        """
        def pick_weights(attr_name):
            # look up the averaged copy first; fall back to the live model only if it is None
            averaged = getattr(self, attr_name + '_avg')
            if averaged is None:
                return getattr(self, attr_name)
            return averaged

        chosen_denoiser = pick_weights('denoiser_model')
        print('save avg model:', self.denoiser_model_avg is not None)
        checkpoint = dict(
            num_steps=self.step,
            model_state_dict=chosen_denoiser.state_dict(),
            optimizer_state_dict=self.optimizer.state_dict(),
        )

        if self.args.use_interaction_model:
            checkpoint['inter_model_state_dict'] = pick_weights('interaction_model').state_dict()
        if not self.args.denoiser_args.load_text_embedding:
            checkpoint['text_encoder_state_dict'] = pick_weights('text_encoder').state_dict()

        destination = (self.args.save_dir / f"checkpoint_{self.step}.pt"
                       if checkpoint_path is None else checkpoint_path)
        torch.save(checkpoint, destination)
        print(f"Saved checkpoint at {destination}")

    def validate(self):
        """Run a no-grad validation pass and log the mean losses to ``self.writer``.

        The denoiser is switched to eval mode for the pass.  So is the text
        encoder when it is trained here (``load_text_embedding`` is false), and
        the interaction model when ``args.use_interaction_model`` is set.  Each
        module's previous train/eval mode is restored afterwards, also when the
        pass raises.

        On step 1 and every ``train_args.save_interval`` steps, the first
        validation batch is additionally rolled out and exported as
        ``real.pkl`` / ``sample.pkl`` pairs under
        ``<save_dir>/samples/<step>/<sample_idx>/``.
        """
        modules = [self.denoiser_model]
        if not self.args.denoiser_args.load_text_embedding:
            modules.append(self.text_encoder)
        if self.args.use_interaction_model:
            interaction_model = getattr(self, 'interaction_model', None)
            if interaction_model is not None:
                modules.append(interaction_model)
        original_modes = [module.training for module in modules]
        for module in modules:
            module.eval()
        try:
            losses_dict = self._collect_validation_losses()
        finally:
            for module, mode in zip(modules, original_modes):
                module.train(mode)

        for k, v in losses_dict.items():
            mean_loss = torch.stack(v).mean().item()
            self.writer.add_scalar(f"val_loss/{k}", mean_loss, self.step)

    def _collect_validation_losses(self):
        """Run every validation batch without gradients.

        :return: ``{loss_name: [detached loss tensor for each validated primitive]}``.
        """
        losses_dict = {}
        # at least 128 batches are drawn, even for small validation sets
        num_batches = max(128, len(self.val_dataset) // self.batch_size)
        with torch.no_grad():
            for val_idx in tqdm(range(num_batches)):
                batch = self.val_dataset.get_batch(self.batch_size)
                if self.args.data_args.interaction:
                    self._validate_interaction_batch(batch, val_idx, losses_dict)
                else:
                    self._validate_single_batch(batch, val_idx, losses_dict)
        return losses_dict

    def _is_val_export_step(self, val_idx):
        """Return whether validation batch ``val_idx`` should be rolled out and exported."""
        train_args = self.args.train_args
        return (self.step == 1 or self.step % train_args.save_interval == 0) and val_idx == 0

    @staticmethod
    def _accumulate_val_losses(losses_dict, loss_dict):
        """Append every detached loss of ``loss_dict`` to the matching list in ``losses_dict``."""
        for k, v in loss_dict.items():
            losses_dict.setdefault(k, []).append(v.detach())

    def _encode_text_into(self, target, texts):
        """Encode ``texts`` with ``self.text_encoder`` and store the result in ``target``.

        A 'v1' encoder yields only ``text_embedding``.  A 'v2'/'v3'/'v4' encoder
        yields ``(text_embedding, text_mask)``.  Any other version leaves
        ``target`` untouched.
        """
        version = self.args.denoiser_args.text_encoder_version
        if version == 'v1':
            target['text_embedding'] = self.text_encoder(texts)
        elif version in ['v2', 'v3', 'v4']:
            target['text_embedding'], target['text_mask'] = self.text_encoder(texts)

    def _validate_interaction_batch(self, batch, val_idx, losses_dict):
        """Validate one two-person batch primitive by primitive; optionally export samples.

        Per-primitive losses are appended to ``losses_dict`` in place.  After
        ``train_args.stage1_steps``, each person's previous prediction and
        transform are fed to ``common_step`` for the next primitive.  Before
        that, ``None`` is passed instead.
        """
        train_args = self.args.train_args
        denoiser_args = self.args.denoiser_args
        history_length = self.train_dataset.history_length
        num_primitive = self.train_dataset.num_primitive
        person_list = ['person1', 'person2']
        export = self._is_val_export_step(val_idx)

        # Snapshot taken before the rollout.  The exported "sample" reads
        # batch['interaction'] afterwards, presumably after common_step updated it.
        gt_interaction = copy.deepcopy(batch['interaction'])
        sample_primitive_seq = {person: [] for person in person_list}
        last_primitive = {person: None for person in person_list}
        last_transf_rotmat = {person: None for person in person_list}
        last_transf_transl = {person: None for person in person_list}
        if denoiser_args.use_pre_latent:
            pre_latent = {person: [] for person in person_list}
            pre_transf_rotmat_abs = {person: [] for person in person_list}
            pre_transf_transl_abs = {person: [] for person in person_list}

        rollout_complete = True
        for primitive_idx in range(num_primitive):
            motion, cond = {}, {}
            for person in person_list:
                motion[person], cond[person] = self.get_primitive_batch(batch[person], primitive_idx)
            interaction = batch['interaction'][primitive_idx]
            # text embedding
            if not denoiser_args.load_text_embedding:
                self._encode_text_into(interaction, interaction['texts'])
                if denoiser_args.use_indi_text:
                    for person in person_list:
                        self._encode_text_into(cond[person]['y'], cond[person]['y']['text'])
            if denoiser_args.use_pre_latent:
                loss_dict, future_motion_pred, last_latent = \
                    self.common_step(motion, cond, last_primitive, interaction,
                                     last_transf_rotmat=last_transf_rotmat,
                                     last_transf_transl=last_transf_transl,
                                     pre_latent=pre_latent,
                                     pre_transf_rotmat_abs=pre_transf_rotmat_abs,
                                     pre_transf_transl_abs=pre_transf_transl_abs)
                for person in person_list:
                    pre_latent[person].append(last_latent[person])
                    pre_transf_rotmat_abs[person].append(cond[person]['y']['transf_rotmat'])
                    pre_transf_transl_abs[person].append(cond[person]['y']['transf_transl'])
            else:
                loss_dict, future_motion_pred = self.common_step(motion, cond, last_primitive, interaction,
                                                                 last_transf_rotmat=last_transf_rotmat,
                                                                 last_transf_transl=last_transf_transl)
            if loss_dict is None:
                # common_step returned no losses: stop this batch's rollout here
                rollout_complete = False
                break

            if export:
                for person in person_list:
                    motion_tensor = motion[person].squeeze(2).permute(0, 2, 1)  # [B, T, D]
                    history_motion_gt = motion_tensor[:, :history_length, :]
                    sample_primitive_seq[person].append(torch.cat([history_motion_gt, future_motion_pred[person]], dim=1))

            self._accumulate_val_losses(losses_dict, loss_dict)

            if self.step > train_args.stage1_steps:
                for person in person_list:
                    last_primitive[person] = future_motion_pred[person].detach()
                    last_transf_rotmat[person] = cond[person]['y']['transf_rotmat'].detach()
                    last_transf_transl[person] = cond[person]['y']['transf_transl'].detach()
            else:
                last_primitive = {person: None for person in person_list}
                last_transf_rotmat = {person: None for person in person_list}
                last_transf_transl = {person: None for person in person_list}

        # An interrupted rollout holds fewer than num_primitive samples per person;
        # exporting it would index past the end of sample_primitive_seq.
        if not (export and rollout_complete):
            return

        def tensor_to_numpy(data):
            """Recursively convert tensors (also inside dicts and lists) to numpy arrays."""
            if isinstance(data, torch.Tensor):
                return data.detach().cpu().numpy()
            elif isinstance(data, dict):
                return {key: tensor_to_numpy(value) for key, value in data.items()}
            elif isinstance(data, list):
                return [tensor_to_numpy(item) for item in data]
            else:
                return data

        dataset = self.val_dataset
        for idx in range(min(self.batch_size, 16)):
            gt_motion_tensor, sample_motion_tensor, gt_seq, sample_seq = {}, {}, {}, {}
            for person in person_list:
                gender = batch[person][0]['gender'][idx]
                betas = torch.stack([batch[person][i]['betas'][idx] for i in range(num_primitive)], dim=0)  # [num_primitive, T, 10]
                if denoiser_args.use_indi_text:
                    texts = [batch[person][i]['texts'][idx] for i in range(num_primitive)]
                gt_motion_tensor[person] = torch.stack(
                    [batch[person][i]['motion_tensor_normalized'][idx] for i in range(num_primitive)], dim=0)  # [num_primitive, D, 1, T]
                gt_motion_tensor[person] = dataset.denormalize(
                    gt_motion_tensor[person].squeeze(2).permute(0, 2, 1))  # [num_primitive, T, D]
                sample_motion_tensor[person] = torch.stack([sample_primitive_seq[person][i][idx] for i in range(num_primitive)], dim=0)
                sample_motion_tensor[person] = dataset.denormalize(sample_motion_tensor[person])  # [num_primitive, T, D]
                # rollout gt and sampled primitives
                gt_seq[person] = self.rollout_primitive_seq(motion_tensor=gt_motion_tensor[person], gender=gender, betas=betas)
                sample_seq[person] = self.rollout_primitive_seq(motion_tensor=sample_motion_tensor[person], gender=gender, betas=betas)
                if denoiser_args.use_indi_text:
                    gt_seq[person]['texts'] = texts
                    sample_seq[person]['texts'] = texts

            inter_texts = [batch['interaction'][i]['texts'][idx] for i in range(num_primitive)]
            gt_seq['inter_texts'] = inter_texts
            sample_seq['inter_texts'] = inter_texts
            gt_seq['rel_pose_b2a'] = gt_interaction[0]['rel_pose_b2a'][idx]
            gt_seq['rel_pose_a2b'] = gt_interaction[0]['rel_pose_a2b'][idx]
            sample_seq['rel_pose_b2a'] = batch['interaction'][0]['rel_pose_b2a'][idx]
            sample_seq['rel_pose_a2b'] = batch['interaction'][0]['rel_pose_a2b'][idx]
            export_dir = Path(self.args.save_dir, 'samples', str(self.step), str(idx))
            export_dir.mkdir(parents=True, exist_ok=True)
            with open(export_dir / 'real.pkl', 'wb') as f:
                pickle.dump(tensor_to_numpy(gt_seq), f)
            with open(export_dir / 'sample.pkl', 'wb') as f:
                pickle.dump(tensor_to_numpy(sample_seq), f)

    def _validate_single_batch(self, batch, val_idx, losses_dict):
        """Validate one single-person batch primitive by primitive; optionally export samples.

        Per-primitive losses are appended to ``losses_dict`` in place.  After
        ``train_args.stage1_steps``, the previous prediction is passed to
        ``common_step`` as ``last_primitive``.  Before that, ``None`` is passed.
        """
        train_args = self.args.train_args
        history_length = self.train_dataset.history_length
        num_primitive = self.train_dataset.num_primitive
        export = self._is_val_export_step(val_idx)

        sample_primitive_seq = []
        last_primitive = None
        rollout_complete = True
        for primitive_idx in range(num_primitive):
            motion, cond = self.get_primitive_batch(batch, primitive_idx)
            if not self.args.denoiser_args.load_text_embedding:
                self._encode_text_into(cond['y'], cond['y']['text'])
            loss_dict, future_motion_pred = self.common_step(motion, cond, last_primitive)
            if loss_dict is None:
                # common_step returned no losses: stop this batch's rollout here
                rollout_complete = False
                break

            if export:
                motion_tensor = motion.squeeze(2).permute(0, 2, 1)  # [B, T, D]
                history_motion_gt = motion_tensor[:, :history_length, :]
                sample_primitive_seq.append(torch.cat([history_motion_gt, future_motion_pred], dim=1))

            self._accumulate_val_losses(losses_dict, loss_dict)

            if self.step > train_args.stage1_steps:
                last_primitive = future_motion_pred.detach()
            else:
                last_primitive = None

        # An interrupted rollout holds fewer than num_primitive samples;
        # exporting it would index past the end of sample_primitive_seq.
        if not (export and rollout_complete):
            return

        dataset = self.val_dataset
        for idx in range(min(self.batch_size, 16)):
            gender = batch[0]['gender'][idx]
            betas = torch.stack([batch[i]['betas'][idx] for i in range(num_primitive)], dim=0)  # [num_primitive, T, 10]
            texts = [batch[i]['texts'][idx] for i in range(num_primitive)]
            gt_motion_tensor = torch.stack(
                [batch[i]['motion_tensor_normalized'][idx] for i in range(num_primitive)], dim=0)  # [num_primitive, D, 1, T]
            gt_motion_tensor = dataset.denormalize(
                gt_motion_tensor.squeeze(2).permute(0, 2, 1))  # [num_primitive, T, D]
            sample_motion_tensor = torch.stack([sample_primitive_seq[i][idx] for i in range(num_primitive)], dim=0)
            sample_motion_tensor = dataset.denormalize(sample_motion_tensor)  # [num_primitive, T, D]
            # rollout gt and sampled primitives
            gt_seq = self.rollout_primitive_seq(motion_tensor=gt_motion_tensor, gender=gender, betas=betas)
            gt_seq['texts'] = texts
            sample_seq = self.rollout_primitive_seq(motion_tensor=sample_motion_tensor, gender=gender, betas=betas)
            sample_seq['texts'] = texts

            export_dir = Path(self.args.save_dir, 'samples', str(self.step), str(idx))
            export_dir.mkdir(parents=True, exist_ok=True)
            with open(export_dir / 'real.pkl', 'wb') as f:
                pickle.dump(gt_seq, f)
            with open(export_dir / 'sample.pkl', 'wb') as f:
                pickle.dump(sample_seq, f)

    def rollout_primitive_seq(self, motion_tensor, gender, betas):
        """Chain per-primitive motions into one world-space motion sequence.

        The transform starts at ``self.transf_rotmat`` / ``self.transf_transl``.
        Each primitive's frames are converted with ``feature_dict_to_smpl_dict``,
        mapped to world coordinates with ``transform_primitive_to_world`` and
        appended along the time axis.  The transform for the next primitive comes
        from ``get_blended_feature`` applied to this primitive's last
        ``history_length`` frames.  The first primitive contributes all of its
        frames; later primitives skip their first ``history_length`` frames.

        :param motion_tensor: denormalized motion tensor, ``[num_primitive, T, D]``.
        :param gender: gender name forwarded to the primitive utilities.
        :param betas: shape parameters, sliced as ``[num_primitive, T, num_betas]``.
        :return: dict with ``gender``, ``betas``, ``body_pose``, ``history_length``
            and ``future_length``.  It also holds ``joints``, ``transl`` and
            ``global_orient`` when the SMPL conversion produced them.  Per-frame
            entries have the leading batch axis removed.
        """
        dataset = self.train_dataset
        num_primitive = dataset.num_primitive
        history_length, future_length = dataset.history_length, dataset.future_length
        primitive_utility = dataset.primitive_utility
        use_predicted_joints = primitive_utility.feature_dim == 262
        transf_rotmat = self.transf_rotmat
        transf_transl = self.transf_transl
        motion_sequences = None
        for primitive_idx in range(num_primitive):
            # later primitives skip their history frames
            start = history_length if primitive_idx > 0 else 0
            future_frames = motion_tensor[[primitive_idx], start:, :]

            # transform primitive to world coordinate, prepare for serialization
            future_feature_dict = primitive_utility.tensor_to_dict(future_frames)
            future_feature_dict.update(
                {
                    'transf_rotmat': transf_rotmat,
                    'transf_transl': transf_transl,
                    'gender': gender,
                    'betas': betas[[primitive_idx], start:, :],
                }
            )
            future_primitive_dict = primitive_utility.feature_dict_to_smpl_dict(future_feature_dict)
            future_primitive_dict = primitive_utility.transform_primitive_to_world(future_primitive_dict)
            if motion_sequences is None:
                motion_sequences = future_primitive_dict
            else:
                for key in ['transl', 'global_orient', 'body_pose', 'betas', 'joints']:
                    if key in motion_sequences:
                        motion_sequences[key] = torch.cat([motion_sequences[key], future_primitive_dict[key]], dim=1)  # [B, T, ...]

            if primitive_idx == num_primitive - 1:
                break  # no further primitive needs the updated transform

            # update history motion seed, update global transform
            new_history_frames = motion_tensor[[primitive_idx], -history_length:, :]
            history_feature_dict = primitive_utility.tensor_to_dict(new_history_frames)
            history_feature_dict.update(
                {
                    'transf_rotmat': transf_rotmat,
                    'transf_transl': transf_transl,
                    'gender': gender,
                    'betas': betas[[primitive_idx], -history_length:, :],
                }
            )
            canonicalized_history_primitive_dict, _ = primitive_utility.get_blended_feature(
                history_feature_dict, use_predicted_joints=use_predicted_joints)
            transf_rotmat = canonicalized_history_primitive_dict['transf_rotmat']
            transf_transl = canonicalized_history_primitive_dict['transf_transl']

        # Build the result in a separate dict: rebinding motion_sequences first
        # would make the optional-key checks below test the new dict, which never
        # has those keys.
        result = {
            'gender': motion_sequences['gender'],
            'betas': motion_sequences['betas'][0],
            'body_pose': motion_sequences['body_pose'][0],
            'history_length': history_length,
            'future_length': future_length,
        }
        for key in ('joints', 'transl', 'global_orient'):
            if key in motion_sequences:
                result[key] = motion_sequences[key][0]  # e.g. joints: [T, 22, 3]
        return result

    def close(self):
        """Release the trainer's logger by closing ``self.writer``."""
        writer = self.writer
        writer.close()

if __name__ == "__main__":
    # cap the intra-op and inter-op CPU thread pools used by torch
    torch.set_num_threads(4)
    torch.set_num_interop_threads(4)
    args = tyro.cli(MLDArgs)
    trainer = Trainer(args)
    try:
        trainer.train()
    finally:
        # close the writer even when training raises or is interrupted (e.g. Ctrl-C),
        # so logs buffered by the writer are not lost
        trainer.close()
