from __future__ import annotations

import sys
sys.path.insert(0, sys.path[0]+r"/../")

import os
import pdb
import random
import time
from typing import Literal
from dataclasses import dataclass, asdict, make_dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import tyro
import yaml
from torch.distributions.normal import Normal
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from tqdm import tqdm
import pickle
import json
import copy

from model.mld_denoiser import DenoiserMLP, DenoiserTransformer
from model.mld_vae import AutoMldVae, AutoMldVaeV2, AutoMldVaeWithAdapter
from data_loaders.humanml.data.dataset import WeightedPrimitiveSequenceDataset, SinglePrimitiveDataset
from utils.smpl_utils import *
from utils.misc_util import encode_text, compose_texts_with_and
from pytorch3d import transforms
from diffusion import gaussian_diffusion as gd
from diffusion.respace import SpacedDiffusion, space_timesteps
from diffusion.resample import create_named_schedule_sampler

from mld.train_mvae import Args as MVAEArgs
from mld.train_mvae import DataArgs, TrainArgs
from mld.train_mld import MLDArgs, create_gaussian_diffusion, load_mld
from model.text_encoder import CLIPTextEncoder, CLIPTextEncoderV2, CLIPTextEncoderV3, CLIPTextEncoderV4

# Module-level debug switch, off (0) by default.
# NOTE(review): nothing in the visible part of this file reads it — confirm it is still used.
debug = 0

@dataclass
class RolloutArgs:
    """Options controlling rollout generation.

    `rollout` reads these through its `rollout_args` parameter. Fields that
    are followed by a string literal are documented by that string; the
    remaining fields are described by the comments above them.
    """

    # Seed value; `rollout` embeds it in the output folder name.
    # Presumably also used to seed the RNGs elsewhere — not visible here.
    seed: int = 0
    # NOTE(review): not read in the visible code; presumably toggles
    # deterministic torch/cuDNN behaviour — confirm against the entry point.
    torch_deterministic: bool = True
    # Torch device string; `rollout` moves its tensors to this device.
    device: str = "cuda:1"

    # Deliberately left without a type annotation: it is therefore a plain
    # class attribute, NOT a dataclass field (no constructor argument).
    # `rollout` evaluates `save_dir / filename`, so it must be replaced by a
    # path-like object (e.g. pathlib.Path) before `rollout` is called.
    save_dir = None
    # Dataset name. In `rollout`'s non-interaction branch the value 'babel'
    # selects ','-separated prompt parsing.
    dataset: str = 'interhuman'

    # Path to the denoiser checkpoint. The third '/'-separated component of
    # this string is used in the name of the rendered video file.
    denoiser_checkpoint: str = './mld_denoiser/mld/checkpoint_300000.pt'
    # '' selects `diffusion.p_sample_loop`; any other value selects
    # `diffusion.ddim_sample_loop` and is prefixed to the output folder name.
    respacing: str = ''
    
    # Sequence id; appears in the rendered video file name.
    seq_id: int = 11

    # Prompt syntax parsed by `rollout`: 'action*N' repeats `action` for N
    # rollout steps; several such segments may be chained with ';'.
    # NOTE(review): `rollout` receives its prompt(s) as a separate argument,
    # so how the four fields below are combined is decided by the caller.
    text_prompt: str = 'the two guys greet each other by waving*20'
    text_prompt_person1: str = 'A person waves back to a person with their right hand*20'
    text_prompt_person2: str = 'A person waves to others with their right hand*20'
    # Empty by default.
    text_prompt_person3: str = ''
            
    batch_size: int = 4
    """batch size for rollout generation"""

    guidance_param: float = 4.0
    """classifier-free guidance parameter for diffusion sampling"""

    export_smpl: int = 0
    """if set to 1, export smplx sequences as npz files for blender visualization"""

    zero_noise: int = 0
    """if set to 1, use zero init noise for sampling"""

    use_predicted_joints: int = 1
    """if set to 1, use predicted joints from models without blending with smplx regressed joints. Setting to 1 will slightly accelerate the rollout process, while setting to 0 provides additional ensurance that the joints form valid smplx bodies"""

    fix_floor: int = 0
    """if set to one, fix the lowest joint to be always on the floor. This can help to ensure floor contact in long sequence generation. However, this is not applicable to actions requiring getting off the floor, such as jumping or climbing stairs"""

    # When True, '_gt' is appended to the rendered video file name.
    # NOTE(review): any further effect is outside the visible code.
    use_gt: bool = False
    
    # Number of persons; `rollout` builds the keys 'person1' .. 'person<n_persons>'.
    n_persons: int = 2

class ClassifierFreeWrapper(nn.Module):
    """Wrap a conditional denoiser so every call applies classifier-free guidance.

    The wrapped model is evaluated twice per call, once with
    ``y['uncond'] = False`` and once with ``y['uncond'] = True``, and the two
    outputs are blended as ``out_uncond + y['scale'] * (out - out_uncond)``.
    """

    def __init__(self, model):
        """Store the wrapped model.

        Args:
            model: callable ``model(x, timesteps, y)`` exposing a
                ``cond_mask_prob`` attribute greater than 0.

        Raises:
            AssertionError: if ``model.cond_mask_prob`` is not positive.
        """
        super().__init__()
        self.model = model  # model is the actual model to run

        assert self.model.cond_mask_prob > 0, 'Cannot run a guided diffusion on a model that has not been trained with no conditions'

    def forward(self, x, timesteps, y=None):
        """Return the guidance-blended prediction of the wrapped model.

        Args:
            x: input forwarded unchanged to the wrapped model.
            timesteps: diffusion timesteps forwarded unchanged to the model.
            y: conditioning dict; must contain ``'scale'`` (the guidance
                weight). It is not modified by this call.

        Returns:
            ``out_uncond + y['scale'] * (out - out_uncond)``.

        Raises:
            TypeError: if ``y`` is None.
            KeyError: if ``y`` has no ``'scale'`` entry.
        """
        if y is None:
            raise TypeError(
                "ClassifierFreeWrapper.forward requires a conditioning dict `y` "
                "containing a 'scale' entry"
            )
        # Shallow copies: each pass gets its own 'uncond' flag, and the caller's
        # dict is not left with 'uncond' = True after the call. Values (tensors)
        # are shared, not copied.
        y_cond = {**y, 'uncond': False}
        y_uncond = {**y, 'uncond': True}
        out = self.model(x, timesteps, y_cond)
        out_uncond = self.model(x, timesteps, y_uncond)
        return out_uncond + (y['scale'] * (out - out_uncond))

def rollout(text_prompt, denoiser_args, denoiser_model, vae_args, vae_model, diffusion, dataset, rollout_args, data_args, text_encoder=None):
    device = rollout_args.device
    batch_size = rollout_args.batch_size
    future_length = dataset.future_length
    history_length = dataset.history_length
    primitive_length = history_length + future_length
    sample_fn = diffusion.p_sample_loop if rollout_args.respacing == '' else diffusion.ddim_sample_loop

    data_args.interaction = 1
    
    if data_args.interaction:
        person_list_N = []
        for i in range(rollout_args.n_persons):
            person_list_N.append(f'person{i+1}')

    if data_args.interaction:
        person_list = ['interaction']
        texts = {'interaction': []}
        print(denoiser_args.use_indi_text)
        if denoiser_args.use_indi_text:
            texts['person1'], texts['person2'] = [], []
            person_list.append('person1')
            person_list.append('person2')
        for person in person_list:
            if person == 'interaction':
                title = ''
            if ';' in text_prompt[person]:  # contain a time line of multipel actions
                num_rollout = 0
                for segment in text_prompt[person].split(';'):
                    action, num_mp = segment.split('*')
                    texts[person] = texts[person] + [action] * int(num_mp)
                    num_rollout += int(num_mp)
                    if person == 'interaction':
                        title = title + action + '. '
            else:
                action, num_rollout = text_prompt[person].split('*')
                num_rollout = int(num_rollout)
                for _ in range(num_rollout):
                    texts[person].append(action)    
    else:
        texts = []
        if rollout_args.dataset == 'babel':
            if ',' in text_prompt:  # contain a time line of multipel actions
                num_rollout = 0
                for segment in text_prompt.split(','):
                    action, num_mp = segment.split('*')
                    action = compose_texts_with_and(action.split(' and '))
                    texts = texts + [action] * int(num_mp)
                    num_rollout += int(num_mp)
            else:
                action, num_rollout = text_prompt.split('*')
                action = compose_texts_with_and(action.split(' and '))
                num_rollout = int(num_rollout)
                for _ in range(num_rollout):
                    texts.append(action)
        else:
            if ';' in text_prompt:  # contain a time line of multipel actions
                num_rollout = 0
                for segment in text_prompt.split(';'):
                    action, num_mp = segment.split('*')
                    texts = texts + [action] * int(num_mp)
                    num_rollout += int(num_mp)
            else:
                action, num_rollout = text_prompt.split('*')
                num_rollout = int(num_rollout)
                for _ in range(num_rollout):
                    texts.append(action)
        

    if not denoiser_args.load_text_embedding:
        all_text_embedding, all_text_mask = {}, {}
        if denoiser_args.text_encoder_version == 'v1':
            for person in person_list:
                all_text_embedding[person] = text_encoder(texts[person]).to(device)
        elif denoiser_args.text_encoder_version in ['v2', 'v3', 'v4']:
            for person in person_list:
                all_text_embedding[person], all_text_mask[person] = text_encoder(texts[person])
                all_text_embedding[person] = all_text_embedding[person].to(device)
    else:
        if data_args.interaction:
            all_text_embedding, all_text_mask = {}, {}
            for person in ['person1', 'person2', 'interaction']:
                if denoiser_args.text_sep:
                    encode_temp = encode_text(dataset.clip_model, texts[person], force_empty_zero=True, text_sep=denoiser_args.text_sep)
                    all_text_embedding[person] = encode_temp[0].to(dtype=torch.float32, device=device)
                    all_text_mask[person] = encode_temp[1].to(dtype=torch.bool, device=device)
                else:    
                    all_text_embedding[person] = encode_text(dataset.clip_model, texts[person], force_empty_zero=True, text_sep=denoiser_args.text_sep).to(dtype=torch.float32, device=device)
        else:
            if denoiser_args.text_sep:
                encode_temp = encode_text(dataset.clip_model, texts, force_empty_zero=True, text_sep=denoiser_args.text_sep)
                all_text_embedding = encode_temp[0].to(dtype=torch.float32, device=device)
                all_text_mask = encode_temp[1].to(dtype=torch.bool, device=device)
            else:    
                all_text_embedding = encode_text(dataset.clip_model, texts, force_empty_zero=True, text_sep=denoiser_args.text_sep).to(dtype=torch.float32, device=device)

    primitive_utility = dataset.primitive_utility
    print('body_type:', primitive_utility.body_type)

    out_path = rollout_args.save_dir
    filename = f'guidance{rollout_args.guidance_param}_seed{rollout_args.seed}'
    if data_args.interaction:
        if text_prompt['interaction'] != '':
            filename = text_prompt['interaction'][:40].replace(' ', '_').replace('.', '') + '_' + filename
    else:
        if text_prompt != '':
            filename = text_prompt[:40].replace(' ', '_').replace('.', '') + '_' + filename
    if rollout_args.respacing != '':
        filename = f'{rollout_args.respacing}_{filename}'
    if rollout_args.zero_noise:
        filename = f'zero_noise_{filename}'
    if rollout_args.use_predicted_joints:
        filename = f'use_pred_joints_{filename}'
    if rollout_args.fix_floor:
        filename = f'fixfloor_{filename}'
    out_path = out_path / filename
    out_path.mkdir(parents=True, exist_ok=True)

    batch = dataset.get_batch(batch_size=rollout_args.batch_size)
    if data_args.interaction:
        input_motions, model_kwargs, gender, betas, pelvis_delta, motion_tensor, history_motion_gt = {}, {}, {}, {}, {}, {}, {}
        for person in person_list_N:
            input_motions[person] = batch[person][0]['motion_tensor_normalized']
            model_kwargs[person] = {'y': batch[person][0]}
            del model_kwargs[person]['y']['motion_tensor_normalized']
            gender[person] = model_kwargs[person]['y']['gender'][0]
            betas[person] = model_kwargs[person]['y']['betas'][:, :primitive_length, :10].to(device)  # [B, H+F, 10]
            pelvis_delta[person] = primitive_utility.calc_calibrate_offset({
                'betas': betas[person][:, 0, :],
                'gender': gender[person],
            })
            input_motions[person] = input_motions[person].to(device)  # [B, D, 1, T]
            motion_tensor[person] = input_motions[person].squeeze(2).permute(0, 2, 1)  # [B, T, D]
            history_motion_gt[person] = motion_tensor[person][:, :history_length, :]  # [B, H, D]
        interaction = batch['interaction']
    else:
        input_motions, model_kwargs = batch[0]['motion_tensor_normalized'], {'y': batch[0]}
        del model_kwargs['y']['motion_tensor_normalized']
        gender = model_kwargs['y']['gender'][0]
        betas = model_kwargs['y']['betas'][:, :primitive_length, :].to(device)  # [B, H+F, 10]
        pelvis_delta = primitive_utility.calc_calibrate_offset({
            'betas': betas[:, 0, :],
            'gender': gender,
        })
        input_motions = input_motions.to(device)  # [B, D, 1, T]
        motion_tensor = input_motions.squeeze(2).permute(0, 2, 1)  # [B, T, D]
        history_motion_gt = motion_tensor[:, :history_length, :]  # [B, H, D]
    
    if text_prompt == '':
        rollout_args.guidance_param = 0.  # Force unconditioned generation


    motion_sequences = None if not data_args.interaction else {person: None for person in person_list_N}
    history_motion = history_motion_gt
    transf_rotmat = {
        person: batch[person][0]['transf_rotmat'].to(device) for person in person_list_N
    } if data_args.interaction else batch[0]['transf_rotmat'].to(device)
    transf_transl = {
        person: batch[person][0]['transf_transl'].to(device) for person in person_list_N
    } if data_args.interaction else batch[0]['transf_transl'].to(device)
    
    if rollout_args.fix_floor:
        if data_args.interaction:
            motion_dict = {}
            for person in person_list_N:
                motion_dict[person] = primitive_utility.tensor_to_dict(dataset.denormalize(history_motion_gt[person]))
                joints = motion_dict[person]['joints'].reshape(batch_size, history_length, 22, 3)  # [B, T, 22, 3]   
                init_floor_height = joints[:, 0, :, 2].amin(dim=-1)  # [B]
                transf_transl[person][:, :, 2] = -init_floor_height.unsqueeze(-1)
        else:
            motion_dict = primitive_utility.tensor_to_dict(dataset.denormalize(history_motion_gt))
            joints = motion_dict['joints'].reshape(batch_size, history_length, 22, 3)  # [B, T, 22, 3]
            init_floor_height = joints[:, 0, :, 2].amin(dim=-1)  # [B]
            transf_transl[:, :, 2] = -init_floor_height.unsqueeze(-1)

    if denoiser_args.use_pre_latent:
        pre_latent = [] if not data_args.interaction else {person: [] for person in person_list_N}
        pre_transf_rotmat_abs = [] if not data_args.interaction else {person: [] for person in person_list_N}
        pre_transf_transl_abs = [] if not data_args.interaction else {person: [] for person in person_list_N}

    for segment_id in tqdm(range(num_rollout)):
        if data_args.interaction:
            text_embedding, text_mask = {}, {}
            for person in person_list:
                if denoiser_args.text_sep:
                    text_embedding[person] = all_text_embedding[person][segment_id].expand(batch_size, -1, -1)  # [B, 512]
                    text_mask[person] = all_text_mask[person][segment_id].expand(batch_size, -1)  # [B, 512]
                else:
                    text_embedding[person] = all_text_embedding[person][segment_id].expand(batch_size, -1)  # [B, 512]
        else:
            if denoiser_args.text_sep:
                text_embedding = all_text_embedding[segment_id].expand(batch_size, -1, -1)  # [B, 512]
                text_mask = all_text_mask[segment_id].expand(batch_size, -1)  # [B, 512]
            else:
                text_embedding = all_text_embedding[segment_id].expand(batch_size, -1)  # [B, 512]
        guidance_param = torch.ones(batch_size, *denoiser_args.model_args.noise_shape).to(device=device) * rollout_args.guidance_param
        if data_args.interaction:
            # if getattr(rollout_args, 'n_persons', 2) >= 3:
            #     if segment_id < 8:
            #         gen_set = {'person1', 'person2'}
            #     else:
            #         gen_set = {'person2', 'person3'}
            # else:
            #     gen_set = set(person_list_N)
            gen_set = set(person_list_N)
            if denoiser_args.merge_his_relpose:
                letters = {'person1': 'a', 'person2': 'b', 'person3': 'c'}
                def _relative_history(src, dst):
                    key = f"rel_pose_{letters[src]}2{letters[dst]}"
                    rel = interaction[0][key]                               # [B, 9] = rot6d(6) | transl(3)
                    R = transforms.rotation_6d_to_matrix(rel[:, :6])        # [B, 3, 3]
                    t = rel[:, 6:9].unsqueeze(1)                            # [B, 1, 3]

                    his_m = copy.deepcopy(history_motion[src])              # [B,H,D]
                    his_denorm = dataset.denormalize(his_m)   # [B,H,D]
                    his_rel = dataset.primitive_utility.relative_transform_feature_tensor(
                        his_denorm, R, t,
                        batch[src][0]['gender'],
                        batch[src][0]['betas'][:, 0],
                    )
                    return dataset.normalize(his_rel)  # [B,H,D]

                history_motion_rel = {}

                if rollout_args.n_persons == 2:
                    p1, p2 = 'person1', 'person2'
                    history_motion_rel[p1] = _relative_history(p1, p2)
                    history_motion_rel[p2] = _relative_history(p2, p1)
                else:
                    for src in person_list_N:
                        history_motion_rel[src] = {}
                        for dst in person_list_N:
                            if src == dst:
                                continue
                            history_motion_rel[src][dst] = _relative_history(src, dst)
                
            latent_pred = {}
            for person in person_list_N:
                if person not in gen_set:
                    continue
                y = {
                    'history_motion_normalized': history_motion[person],
                    'text_inter': texts['interaction'],
                    'text_embedding_inter': text_embedding['interaction'],
                    'rel_pose': interaction[0]['rel_pose_'+'b2a' if person == 'person1' else 'rel_pose_'+'a2b'],
                    'scale': guidance_param,
                }
                if denoiser_args.use_indi_text:
                    y['text_embedding'] = text_embedding[person]
                if denoiser_args.text_sep:
                    if denoiser_args.use_indi_text:
                        y['text_mask'] = text_mask[person]
                    y['text_mask_inter'] = text_mask['interaction']

                others_gen = [p for p in gen_set if p != person]

                if len(others_gen) == 0:
                    y['history_motion_normalized_b'] = history_motion[person][:, :0, :]  # [B, 0, D]
                else:
                    if not denoiser_args.merge_his_relpose:
                        y['history_motion_normalized_b'] = torch.cat(
                            [history_motion[p] for p in others_gen], dim=1
                        )  # [B, sum(H_other_gen), D]
                    else:
                        rel_hists = []
                        for other in others_gen:
                            rel_entry = history_motion_rel[other]
                            if isinstance(rel_entry, dict):
                                rel_hists.append(rel_entry[person])     # [B, H, D]
                            else:
                                rel_hists.append(rel_entry)             # [B, H, D]
                        y['history_motion_normalized_b'] = torch.cat(rel_hists, dim=1)  # [B, sum(H_other_gen), D]

                if denoiser_args.use_pre_latent:
                    y['pre_latent'] = torch.cat(pre_latent[person], dim=1) if len(pre_latent[person]) != 0 else None
                    if len(pre_transf_rotmat_abs[person]) == 0:
                        y['pre_reltrans'] = None
                    else:
                        y['pre_reltrans'] = []
                        for transf_rotmat_abs, transf_transl_abs in zip(pre_transf_rotmat_abs[person], pre_transf_transl_abs[person]):
                            rel_rotmat, rel_transl = primitive_utility.compute_rel_transform_B_in_A(
                                transf_rotmat[person], transf_transl[person], transf_rotmat_abs, transf_transl_abs
                            )
                            y['pre_reltrans'].append(
                                torch.cat([transforms.matrix_to_rotation_6d(rel_rotmat), rel_transl.squeeze(1)], dim=-1).unsqueeze(1)
                            )
                        y['pre_reltrans'] = torch.cat(y['pre_reltrans'], dim=1)  # [B, num_primitive, 6+3]

                x_start_pred = sample_fn(
                    denoiser_model,
                    (batch_size, *denoiser_args.model_args.noise_shape),
                    clip_denoised=False,
                    model_kwargs={'y': y},
                    skip_timesteps=0,
                    init_image=None,
                    progress=False,
                    dump_steps=None,
                    noise=torch.zeros_like(guidance_param) if rollout_args.zero_noise else None,
                    const_noise=False,
                )  # [B, T=1, D]
                latent_pred[person] = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
        else:
            y = {
                'text_embedding': text_embedding,
                'history_motion_normalized': history_motion,
                'scale': guidance_param,
            }
            if denoiser_args.text_sep:
                y['text_mask'] = text_mask

            x_start_pred = sample_fn(
                denoiser_model,
                (batch_size, *denoiser_args.model_args.noise_shape),
                clip_denoised=False,
                model_kwargs={'y': y},
                skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                init_image=None,
                progress=False,
                dump_steps=None,
                noise=torch.zeros_like(guidance_param) if rollout_args.zero_noise else None,
                const_noise=False,
            )  # [B, T=1, D]
            latent_pred = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
        
        if denoiser_args.use_pre_latent:
            if data_args.interaction:
                for person in person_list_N:
                    if person not in gen_set:
                        continue
                    pre_latent[person].append(latent_pred[person].permute(1, 0, 2))
                    pre_transf_rotmat_abs[person].append(transf_rotmat[person])
                    pre_transf_transl_abs[person].append(transf_transl[person])
            else:
                pre_latent.append(latent_pred)
                pre_transf_rotmat_abs.append(transf_rotmat)
                pre_transf_transl_abs.append(transf_transl)
        
        if data_args.interaction:
            future_motion_pred, future_frames, all_frames = {}, {}, {}
            def _frozen_future(person_key):
                last = dataset.denormalize(history_motion[person_key])[:, -1:, :]      # [B,1,D]
                return last.expand(-1, future_length, -1)                              # [B,F,D]

            for person in person_list_N:
                if person in gen_set:
                    future_motion_pred[person] = vae_model.decode(
                        latent_pred[person], history_motion[person], nfuture=future_length,
                        scale_latent=denoiser_args.rescale_latent
                    )
                    future_frames[person] = dataset.denormalize(future_motion_pred[person])
                else:
                    future_frames[person] = _frozen_future(person)

                all_frames[person] = torch.cat([dataset.denormalize(history_motion[person]), future_frames[person]], dim=1)

        else:
            future_motion_pred = vae_model.decode(latent_pred, history_motion, nfuture=future_length,
                                                  scale_latent=denoiser_args.rescale_latent)  # [B, F, D], normalized
            future_frames = dataset.denormalize(future_motion_pred)
            all_frames = torch.cat([dataset.denormalize(history_motion), future_frames], dim=1)

        """transform primitive to world coordinate, prepare for serialization"""
        if segment_id == 0:  # add init history motion
            future_frames = all_frames
        
        if data_args.interaction:
            if rollout_args.fix_floor:
                for person in person_list_N:
                    future_feature_dict[person] = primitive_utility.tensor_to_dict(future_frames[person])
                    joints = future_feature_dict[person]['joints'].reshape(batch_size, -1, 22, 3)
                    joints = torch.einsum('bij,btkj->btki', transf_rotmat[person], joints) + transf_transl[person].unsqueeze(1)
                    min_height = joints[:, :, :, 2].amin(dim=-1)  # [B, T]
                    transl_floor = torch.zeros(batch_size, joints.shape[1], 3, device=device, dtype=torch.float32)  # [B, T, 3]
                    transl_floor[:, :, 2] = - min_height
                    future_feature_dict[person]['transl'] += transl_floor
                    transl_delta_local = torch.einsum('bij,bti->btj', transf_rotmat[person], transl_floor)
                    joints += transl_delta_local.unsqueeze(2)
                    future_feature_dict[person]['joints'] = joints.reshape(batch_size, -1, 66)
                    future_frames[person] = primitive_utility.dict_to_tensor(future_feature_dict[person])
            for person in person_list_N:
                future_feature_dict = primitive_utility.tensor_to_dict(future_frames[person])
                future_feature_dict.update(
                    {
                        'transf_rotmat': transf_rotmat[person],
                        'transf_transl': transf_transl[person],
                        'gender': gender[person],
                        'betas': betas[person][:, :future_length, :] if segment_id > 0 else betas[person][:, :primitive_length, :],
                        'pelvis_delta': pelvis_delta[person],
                    }
                )
                future_primitive_dict = primitive_utility.feature_dict_to_smpl_dict(future_feature_dict)
                future_primitive_dict = primitive_utility.transform_primitive_to_world(future_primitive_dict)

                if motion_sequences[person] is None:
                    motion_sequences[person] = future_primitive_dict
                else:
                    for key in motion_sequences[person].keys():
                        if key in ['transl', 'global_orient', 'body_pose', 'betas', 'joints']:
                            motion_sequences[person][key] = torch.cat([motion_sequences[person][key], future_primitive_dict[key]], dim=1)  # [B, T, ...]

                """update history motion seed, update global transform"""
                new_history_frames = all_frames[person][:, -history_length:, :]
                history_feature_dict = primitive_utility.tensor_to_dict(new_history_frames)
                history_feature_dict.update(
                    {
                        'transf_rotmat': transf_rotmat[person],
                        'transf_transl': transf_transl[person],
                        'gender': gender[person],
                        'betas': betas[person][:, :history_length, :],
                        'pelvis_delta': pelvis_delta[person],
                    }
                )
                canonicalized_history_primitive_dict, blended_feature_dict = primitive_utility.get_blended_feature(
                    history_feature_dict, use_predicted_joints=rollout_args.use_predicted_joints)
                transf_rotmat[person], transf_transl[person] = canonicalized_history_primitive_dict['transf_rotmat'], \
                canonicalized_history_primitive_dict['transf_transl']
                history_motion[person] = primitive_utility.dict_to_tensor(blended_feature_dict)
                history_motion[person] = dataset.normalize(history_motion[person])  # [B, T, D]
            letters_fixed = {'person1': 'a', 'person2': 'b', 'person3': 'c'}
            letters = {}
            for i, p in enumerate(person_list_N):
                letters[p] = letters_fixed.get(p, chr(ord('a') + i))

            for src in person_list_N:
                for dst in person_list_N:
                    if src == dst:
                        continue
                    
                    rel_R, rel_t = primitive_utility.compute_rel_transform_B_in_A(
                        transf_rotmat[dst], transf_transl[dst],
                        transf_rotmat[src], transf_transl[src]
                    )
                    rel_6d = transforms.matrix_to_rotation_6d(rel_R)
                    rel_vec = torch.cat([rel_6d, rel_t.squeeze(1)], dim=-1)  # [B, 6+3]

                    key = f"rel_pose_{letters[src]}2{letters[dst]}"          
                    if denoiser_args.normalize_relpose:
                        interaction[0][key] = dataset.normalize_rel_pose(rel_vec)
                    else:
                        interaction[0][key] = rel_vec
        else:
            if rollout_args.fix_floor:
                future_feature_dict = primitive_utility.tensor_to_dict(future_frames)
                joints = future_feature_dict['joints'].reshape(batch_size, -1, 22, 3)  # [B, T, 22, 3]
                joints = torch.einsum('bij,btkj->btki', transf_rotmat, joints) + transf_transl.unsqueeze(1)
                min_height = joints[:, :, :, 2].amin(dim=-1)  # [B, T]
                transl_floor = torch.zeros(batch_size, joints.shape[1], 3, device=device, dtype=torch.float32)  # [B, T, 3]
                transl_floor[:, :, 2] = - min_height
                future_feature_dict['transl'] += transl_floor
                transl_delta_local = torch.einsum('bij,bti->btj', transf_rotmat, transl_floor)
                joints += transl_delta_local.unsqueeze(2)
                future_feature_dict['joints'] = joints.reshape(batch_size, -1, 66)
                future_frames = primitive_utility.dict_to_tensor(future_feature_dict)
            future_feature_dict = primitive_utility.tensor_to_dict(future_frames)
            future_feature_dict.update(
                {
                    'transf_rotmat': transf_rotmat,
                    'transf_transl': transf_transl,
                    'gender': gender,
                    'betas': betas[:, :future_length, :] if segment_id > 0 else betas[:, :primitive_length, :],
                    'pelvis_delta': pelvis_delta,
                }
            )
            future_primitive_dict = primitive_utility.feature_dict_to_smpl_dict(future_feature_dict)
            future_primitive_dict = primitive_utility.transform_primitive_to_world(future_primitive_dict)


            if motion_sequences is None:
                motion_sequences = future_primitive_dict
            else:
                for key in motion_sequences.keys():
                    if key in ['transl', 'global_orient', 'body_pose', 'betas', 'joints']:
                        motion_sequences[key] = torch.cat([motion_sequences[key], future_primitive_dict[key]], dim=1)  # [B, T, ...]

            """update history motion seed, update global transform"""
            new_history_frames = all_frames[:, -history_length:, :]
            history_feature_dict = primitive_utility.tensor_to_dict(new_history_frames)
            history_feature_dict.update(
                {
                    'transf_rotmat': transf_rotmat,
                    'transf_transl': transf_transl,
                    'gender': gender,
                    'betas': betas[:, :history_length, :],
                    'pelvis_delta': pelvis_delta,
                }
            )
            canonicalized_history_primitive_dict, blended_feature_dict = primitive_utility.get_blended_feature(
                history_feature_dict, use_predicted_joints=rollout_args.use_predicted_joints)
            transf_rotmat, transf_transl = canonicalized_history_primitive_dict['transf_rotmat'], \
            canonicalized_history_primitive_dict['transf_transl']
            history_motion = primitive_utility.dict_to_tensor(blended_feature_dict)
            history_motion = dataset.normalize(history_motion)  # [B, T, D]

    if data_args.interaction:
        person_keys = sorted([k for k in motion_sequences.keys() if k.startswith('person')],
                             key=lambda x: int(x.replace('person', '')))
        num_persons = len(person_keys)

        motion_sequences_temp = {}
        motion_sequences_temp['gender'] = [motion_sequences[p]['gender'] for p in person_keys]

        merge_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'joints']
        for key in motion_sequences[person_keys[0]].keys():
            if key in merge_keys:
                motion_sequences_temp[key] = torch.cat(
                    [motion_sequences[p][key] for p in person_keys], dim=-1
                )
        motion_sequences = motion_sequences_temp

    if not os.path.exists(out_path):
        os.makedirs(out_path)

    def _to_rotmat(flat, per_dim, num_joints):
        """Convert flattened per-joint rotations into rotation matrices.

        The representation is inferred from ``per_dim // num_joints``:
        9 -> flattened 3x3 matrices, 6 -> 6D rotations, 3 -> axis-angle.

        Args:
            flat: tensor whose last dimension packs ``num_joints`` rotations
                back to back (``per_dim`` values in total).
            per_dim: number of values in the last dimension of ``flat``.
            num_joints: number of rotations packed into the last dimension.

        Returns:
            Tensor of shape ``[..., num_joints, 3, 3]``.

        Raises:
            ValueError: if the per-joint size is not 9, 6 or 3. Previously
                this case fell through and returned ``None``, which only
                surfaced later as an unrelated error at the call site.
        """
        each = per_dim // num_joints
        lead_shape = flat.shape[:-1]
        if each == 9:
            return flat.reshape(*lead_shape, num_joints, 3, 3)
        if each == 6:
            # The converter operates on the last dimension and keeps all
            # leading dimensions, so no per-joint loop is required.
            return transforms.rotation_6d_to_matrix(
                flat.reshape(*lead_shape, num_joints, 6)
            )  # [..., num_joints, 3, 3]
        if each == 3:
            return transforms.axis_angle_to_matrix(
                flat.reshape(*lead_shape, num_joints, 3)
            )  # [..., num_joints, 3, 3]
        raise ValueError(
            f"Unsupported rotation size {each} per joint "
            f"(per_dim={per_dim}, num_joints={num_joints}); expected 9, 6 or 3."
        )

    for idx in range(rollout_args.batch_size):
        sequence = {
            'texts': texts,
            'gender': motion_sequences['gender'],               
            'betas': motion_sequences['betas'][idx],           # [T, 10*num_persons] or [*, 10*num_persons]
            'body_pose': motion_sequences['body_pose'][idx],   # [T, body_pose_dim_total]
            'joints': motion_sequences['joints'][idx],         # [T, joints_dim_total]
            'history_length': history_length,
            'future_length': future_length,
        }
        if 'transl' in motion_sequences:
            sequence['transl'] = motion_sequences['transl'][idx]
        if 'global_orient' in motion_sequences:
            sequence['global_orient'] = motion_sequences['global_orient'][idx]

        tensor_dict_to_device(sequence, 'cpu')
        with open(out_path / f'sample_{idx}.pkl', 'wb') as f:
            pickle.dump(sequence, f)
        try:
            from visualize.vis_keypoints import plot_t2m
            num_persons = len(sequence['gender'])
            if 'joints' in sequence and sequence['joints'].ndim >= 2:
                perJ = sequence['joints'].shape[-1] // num_persons
                joints_list = [sequence['joints'][..., i*perJ:(i+1)*perJ] for i in range(num_persons)]
                suffix = '_gt' if rollout_args.use_gt else ''
                file_saved = f"test_{(rollout_args.denoiser_checkpoint).split('/')[2]}_{rollout_args.seq_id}{suffix}_sample{idx}.mp4"
                plot_t2m(joints_list, file_saved, title, radius=6, vertical='y')
                print(file_saved)
        except Exception as e:
            print(f"Error: {e}")

        if rollout_args.export_smpl:
            num_persons = len(sequence['gender'])
            go_total = sequence['global_orient'].shape[-1] if 'global_orient' in sequence else None
            bp_total = sequence['body_pose'].shape[-1]
            tr_total = sequence['transl'].shape[-1] if 'transl' in sequence else None
            bt_total = sequence['betas'].shape[-1]

            go_per = go_total // num_persons if go_total is not None else None
            bp_per = bp_total // num_persons
            tr_per = tr_total // num_persons if tr_total is not None else None
            bt_per = bt_total // num_persons 

            for p_idx in range(num_persons):
                betas_p = sequence['betas'][..., p_idx*bt_per:(p_idx+1)*bt_per]              # [T, 10]
                trans_p = sequence['transl'][..., p_idx*tr_per:(p_idx+1)*tr_per] if tr_per else None
                go_flat = sequence['global_orient'][..., p_idx*go_per:(p_idx+1)*go_per] if go_per else None
                bp_flat = sequence['body_pose'][..., p_idx*bp_per:(p_idx+1)*bp_per]

                if go_flat is not None:
                    go_rot = _to_rotmat(go_flat, per_dim=go_per, num_joints=1)             # [T, 1, 3, 3]
                else:
                    eye = torch.eye(3).reshape(1, 1, 3, 3).repeat(bp_flat.shape[0], 1, 1, 1)
                    go_rot = eye

                bp_rot = _to_rotmat(bp_flat, per_dim=bp_per, num_joints=21)                # [T, 21, 3, 3]

                rot_stack = torch.cat([go_rot, bp_rot], dim=-3)                             # [T, 22, 3, 3]
                poses_axis = transforms.matrix_to_axis_angle(rot_stack.reshape(-1, 3, 3))   # [T*22, 3]
                poses_axis = poses_axis.reshape(-1, 22*3)                                   # [T, 66]

                poses_axis_pad = torch.cat(
                    [poses_axis, torch.zeros(poses_axis.shape[0], 99, dtype=poses_axis.dtype)], dim=-1
                )  # [T, 165]

                data_dict_p = {
                    'mocap_framerate': min(dataset.target_fps, 30),
                    'gender': sequence['gender'][p_idx],
                    'betas': betas_p[0, :10].detach().cpu().numpy(),
                    'poses': poses_axis_pad.detach().cpu().numpy(),
                }
                if trans_p is not None:
                    data_dict_p['trans'] = trans_p.detach().cpu().numpy()      # [T, 3]

                with open(out_path / f'sample_{idx}_smplx_p{p_idx+1}.npz', 'wb') as f:
                    np.savez(f, **data_dict_p)


    abs_path = out_path.absolute()
    print(f'[Done] Results are at [{abs_path}]')

if __name__ == '__main__':
    # Script entry point: parse CLI arguments, load the denoiser / VAE
    # checkpoint, build the diffusion process, the optional text encoder and
    # the seed dataset, then run `rollout` for the requested prompt(s).
    rollout_args = tyro.cli(RolloutArgs)
    # TRY NOT TO MODIFY: seeding
    random.seed(rollout_args.seed)
    np.random.seed(rollout_args.seed)
    torch.manual_seed(rollout_args.seed)
    torch.set_default_dtype(torch.float32)
    torch.backends.cudnn.deterministic = rollout_args.torch_deterministic
    device = torch.device(rollout_args.device if torch.cuda.is_available() else "cpu")
    rollout_args.device = device

    # NOTE(review): the assignments below unconditionally overwrite any
    # prompts supplied on the command line. Kept as-is to preserve current
    # behavior; remove them to make the CLI prompt options effective.
    rollout_args.text_prompt = "They greet each another by waving*3; They bow to each another*2; They take steps back together*3; They jump up*3"
    rollout_args.text_prompt_person1 = ""
    rollout_args.text_prompt_person2 = ""
    if rollout_args.n_persons == 3:
        rollout_args.text_prompt_person3 = ""

    denoiser_args, denoiser_model, vae_args, vae_model, data_args = load_mld(rollout_args.denoiser_checkpoint, device)
    denoiser_checkpoint = Path(rollout_args.denoiser_checkpoint)
    # Results go to <checkpoint dir>/<checkpoint stem>/rollout.
    save_dir = denoiser_checkpoint.parent / denoiser_checkpoint.name.split('.')[0] / 'rollout'
    save_dir.mkdir(parents=True, exist_ok=True)
    rollout_args.save_dir = save_dir

    diffusion_args = denoiser_args.diffusion_args
    diffusion_args.respacing = rollout_args.respacing
    print('diffusion_args:', asdict(diffusion_args))
    diffusion = create_gaussian_diffusion(diffusion_args)

    # The text encoder is only built when the denoiser was not trained on
    # pre-computed text embeddings; otherwise `rollout` receives None.
    text_encoder = None
    if not denoiser_args.load_text_embedding:
        if denoiser_args.text_encoder_version == "v1":
            text_encoder = CLIPTextEncoder(denoiser_args.clip_version, clip_device=device)
        elif denoiser_args.text_encoder_version == "v2":
            text_encoder = CLIPTextEncoderV2(denoiser_args.clip_version, clip_final_proj=denoiser_args.clip_final_proj, clip_device=device)
        elif denoiser_args.text_encoder_version == "v3":
            text_encoder = CLIPTextEncoderV3(denoiser_args.clip_version, clip_final_proj=denoiser_args.clip_final_proj, clip_device=device)
        else:
            # Fail with a clear message instead of a NameError on the
            # unbound `text_encoder` below.
            raise ValueError(
                f"Unsupported text_encoder_version: {denoiser_args.text_encoder_version!r} "
                "(expected 'v1', 'v2' or 'v3')"
            )
        text_encoder.to(device)
        text_encoder_ckpt = torch.load(denoiser_checkpoint, map_location=device)
        text_encoder.load_state_dict(text_encoder_ckpt['text_encoder_state_dict'])
        text_encoder.eval()

    # load initial seed dataset
    if rollout_args.dataset == 'interhuman':
        motion_repr = {
            'transl': 3,
            'poses_6d': 22 * 6,
            'transl_delta': 3,
            'global_orient_delta_6d': 6,
            'joints': 22 * 3,
            'joints_delta': 22 * 3,
        }
    else:
        motion_repr = {
            'joints': 22 * 3,
            'joints_delta': 22 * 3,
            'body_pose': 21 * 6,
            'feet_contact': 4,
        }

    sequence_path = './data/t_pose_interaction_exchangeyz_three.pkl'
    mode = 'merged' if data_args.interaction else 'sep'
    dataset = SinglePrimitiveDataset(cfg_path=data_args.cfg_path,  # cfg path from model checkpoint
                                     dataset_path=data_args.data_dir,  # dataset path from model checkpoint
                                     body_type=data_args.body_type,
                                     sequence_path=sequence_path,
                                     batch_size=rollout_args.batch_size,
                                     device=device,
                                     enforce_gender='male',
                                     enforce_zero_beta=1,
                                     mode=mode,
                                     motion_repr=motion_repr,
                                     padding=data_args.padding,
                                     normalize_relpose=denoiser_args.normalize_relpose,
                                     )

    # A short `text_prompt` that names an existing file is treated as a list
    # of prompts, one per line; otherwise it is used as the prompt itself.
    if len(rollout_args.text_prompt) < 256 and Path(rollout_args.text_prompt).exists():
        with open(rollout_args.text_prompt, 'r') as f:
            texts = [line.strip() for line in f]
        # The file is closed before the (long-running) rollouts start.
        for text_prompt in texts:
            print(f'Generating [{text_prompt}]')
            rollout(text_prompt, denoiser_args, denoiser_model, vae_args, vae_model, diffusion, dataset, rollout_args, data_args, text_encoder=text_encoder)
    else:
        if data_args.interaction:
            text_prompt = {
                'person1': rollout_args.text_prompt_person1,
                'person2': rollout_args.text_prompt_person2,
                'interaction': rollout_args.text_prompt,
            }
            # Only the dict form can carry a third person's prompt; the
            # plain-string prompt used without interaction cannot be indexed.
            if rollout_args.n_persons == 3:
                text_prompt['person3'] = rollout_args.text_prompt_person3
        else:
            text_prompt = rollout_args.text_prompt
        rollout(text_prompt, denoiser_args, denoiser_model, vae_args, vae_model, diffusion, dataset, rollout_args, data_args, text_encoder=text_encoder)


