import sys
sys.path.insert(0, sys.path[0]+r"/../../../")

import random
import clip
import codecs as cs
from utils.intergen_util import cal_rel_rot
from data_loaders.humanml.data.dataset import *
from utils.word_vectorizer import WordVectorizer

class InterXDataset(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interx',
                 dataset_path='./data/Inter-X/seq_data_single_interaction_fps30',
                 cfg_path='./config_files/config_hydra/motion_primitive/interx_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplx',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='merged', # 'sep' or 'merged'
                 text_sep = False,
                 max_segs = 20,
                 motion_repr=None,
                 **kwargs):
        """Load an Inter-X split, its normalization statistics and (optionally) text embeddings.

        Args:
            dataset_name: Stored on the instance as ``self.dataset_name``.
            dataset_path: Directory holding ``{split}.pkl`` plus the cached
                mean/std and text-embedding pickles read/written here.
            cfg_path: YAML file loaded with OmegaConf; ``fps``, ``history_length``,
                ``future_length`` and ``num_primitive`` are read from it.
            split: Name of the pickle to load. Mean/std can only be (re)computed
                when this is ``'train'``.
            device: Device handed to ``PrimitiveUtility``, CLIP and the text tensors.
            weight_scheme: Sequence sampling scheme. A value containing ``'uniform'``
                gives every sequence weight 1; one containing ``'length'`` weights
                by the number of non-padded frames.
            prob_static: Stored on the instance; not used in this constructor.
            enforce_gender: When not None, overrides the gender of both persons.
            enforce_zero_beta: When truthy, shape parameters are replaced by zeros.
            load_data: When False the split pickle is not read and ``self.dataset``
                is not set.
            text_tolerance: Slack added on both sides of the future window when
                matching text segments (see ``get_primitive``).
            body_type: Forwarded to ``PrimitiveUtility``.
            seed_only: When True, the config must have ``num_primitive == 1``.
            use_frame_weights: Accepted for interface compatibility; not used in
                this constructor.
            mode: ``'merged'`` or ``'sep'``; selects ``self.key_list``.
            text_sep: When True, texts are embedded per segment and a segment
                mask dictionary is kept next to the embedding dictionary.
            max_segs: Maximum number of text segments when ``text_sep`` is True.
            motion_repr: Mapping of feature name to feature dimension forwarded to
                ``PrimitiveUtility``. Defaults to
                ``{'body_pose': 55 * 6, 'transl': 3, 'transl_delta': 3}``. A
                joints-based alternative is
                ``{'joints': 55 * 3, 'joints_delta': 55 * 3, 'body_pose': 55 * 6}``.
            **kwargs: Optional ``sep_mode`` (default 0), ``padding`` (default False),
                ``clip_version`` (default ``'ViT-B/32'``), ``load_text_embedding``
                (default False) and ``use_indi_text`` (default False).

        Raises:
            AssertionError: If ``seed_only`` is set with ``num_primitive != 1``, or
                if the mean/std cache is missing and ``split`` is not ``'train'``.
            ValueError: If ``weight_scheme`` matches no known scheme and a loaded
                sequence carries no precomputed ``'weight'``.

        NOTE(review): the base-class initializer is not invoked here; every
        attribute used by this class is set explicitly below.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.sep_mode = kwargs.get('sep_mode', 0)
        self.padding = kwargs.get('padding', False)
        self.key_list = ['person1', 'person2', 'interaction'] if self.mode=='merged' else ['person1', 'person2']
        self.feet_thre = 0.001
        self.n_joints = 55

        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', False)
        self.use_indi_text = kwargs.get('use_indi_text', False)
        self.text_key_list = ['person1', 'person2', 'interaction'] if self.use_indi_text else ['interaction']

        # Build the default representation per call: a dict literal in the
        # signature would be one object shared by every instance (and by the
        # PrimitiveUtility it is handed to).
        if motion_repr is None:
            motion_repr = {
                'body_pose': 55 * 6,
                'transl': 3,
                'transl_delta': 3,
            }
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # One extra frame on top of history + all future windows.
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1

        if load_data:
            # NOTE: pickle.load can execute arbitrary code; only point
            # dataset_path at trusted files.
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            if not self.padding:
                # Without padding, sequences shorter than one training window are dropped.
                dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            def convert_motion(motion, gender, enforce_zero_beta):
                """Convert one person's raw numpy motion record into a dict of torch tensors."""
                betas = torch.from_numpy(motion['betas'].astype(np.float32))
                if enforce_zero_beta:
                    betas = torch.zeros_like(betas)
                transl = torch.from_numpy(motion['trans'].astype(np.float32))
                poses = torch.from_numpy(motion['poses'].astype(np.float32))
                # Axis-angle layout of `poses`: 3 root + 21*3 body + 3 jaw
                # + 3 left eye + 3 right eye + 15*3 left hand + remaining right hand.
                global_orient = transforms.axis_angle_to_matrix(poses[:, :3])
                body_pose = transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3))
                jaw_pose = transforms.axis_angle_to_matrix(poses[:, 66:69])
                leye_pose = transforms.axis_angle_to_matrix(poses[:, 69:72])
                reye_pose = transforms.axis_angle_to_matrix(poses[:, 72:75])
                pose_lhand = transforms.axis_angle_to_matrix(poses[:, 75:120].reshape(-1, 15, 3))
                pose_rhand = transforms.axis_angle_to_matrix(poses[:, 120:].reshape(-1, 15, 3))
                pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))              # [3]
                joints = torch.from_numpy(motion['joints'].astype(np.float32))                          # [T, 55, 3]
                result = {
                    'gender': gender,
                    'betas': betas,
                    'transl': transl,
                    'global_orient': global_orient,
                    'body_pose': body_pose,
                    'jaw_pose': jaw_pose,
                    'left_eye_pose': leye_pose,
                    'right_eye_pose': reye_pose,
                    'left_hand_pose': pose_lhand,
                    'right_hand_pose': pose_rhand,
                    'pelvis_delta': pelvis_delta,
                    'joints': joints,
                }
                if self.padding:
                    result['padding_mask'] = motion['padding_mask']
                return result

            for data in tqdm(dataset, desc='Processing dataset'):
                if self.padding:
                    # Length is taken from person 1 and applied to both persons.
                    T = data['motion_p1']['trans'].shape[0]
                    if T < self.seq_length:
                        pad_len = self.seq_length - T
                        for person in ['motion_p1', 'motion_p2']:
                            for key in ['trans', 'poses', 'joints']:
                                # Pad by repeating the last real frame.
                                last_frame = data[person][key][-1:]
                                repeated_tail = np.repeat(last_frame, pad_len, axis=0)
                                data[person][key] = np.concatenate([
                                    data[person][key],
                                    repeated_tail
                                ], axis=0)

                            # padding_mask: True marks a padded (synthetic) frame
                            data[person]['padding_mask'] = np.concatenate([
                                np.zeros(T, dtype=np.bool_),
                                np.ones(pad_len, dtype=np.bool_)
                            ], axis=0)
                    else:
                        for person in ['motion_p1', 'motion_p2']:
                            data[person]['padding_mask'] = np.zeros(T, dtype=np.bool_)

                gender_p1 = self.enforce_gender if self.enforce_gender is not None else data['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else data['motion_p2']['gender']
                data['motion_p1'] = convert_motion(data['motion_p1'], gender_p1, self.enforce_zero_beta)
                data['motion_p2'] = convert_motion(data['motion_p2'], gender_p2, self.enforce_zero_beta)

            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    if self.padding and len(data['motion_p1']['transl'])==self.seq_length:
                        # Padded frames must not make a short sequence look longer.
                        data['weight'] = len(data['motion_p1']['transl'])-sum(data['motion_p1']['padding_mask'])
                    else:
                        data['weight'] = len(data['motion_p1']['transl'])
                elif 'weight' not in data:
                    # Fail here with a clear message rather than with a bare
                    # KeyError when the weights are gathered below.
                    raise ValueError(
                        f"unsupported weight_scheme {weight_scheme!r}: expected it to contain "
                        "'uniform' or 'length', or the sequence to provide a precomputed 'weight'"
                    )
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights
        self._curr_test_index = 0

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'

        suffix = '_padding' if self.padding else ''
        mean_std_path = Path(dataset_path, f'{file_name}{suffix}.pkl')

        if mean_std_path.exists():
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        else:
            # Statistics must come from the training split only.
            assert self.split == 'train'
            print('calculating mean and std using train split')
            result = self.calc_mean_std()
            self.tensor_mean, self.tensor_std = result

            # Cached on CPU so the file can be loaded on any machine.
            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        if self.load_text_embedding:
            self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
            self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
            # Non-default CLIP versions get their own cache files.
            suffix = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
            self.embedding_path = {}
            embedding_path = {}
            for key_type in self.text_key_list:
                if text_sep:
                    self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict_textsep_sepmode{self.sep_mode}{suffix}.pkl')
                else:
                    self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{suffix}.pkl')
            self.text_embedding_dict = {}
            if text_sep:
                self.text_mask_dict = {}

            for key_type in self.text_key_list:
                if embedding_path[key_type].exists():
                    print(f"Loading text_{key_type} embeddings from {embedding_path[key_type]}!")
                    with open(embedding_path[key_type], 'rb') as f:
                        self.text_embedding_dict[key_type] = pickle.load(f)
                    if text_sep:
                        # The mask cache sits next to the embedding cache under a derived name.
                        with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'rb') as f:
                            self.text_mask_dict[key_type] = pickle.load(f)
                else:
                    print('Calculating text embeddings')
                    raw_texts = []
                    for data in self.dataset:
                        if f'frame_labels_{key_type}' in data:
                            raw_texts.extend([seg['proc_label'] for seg in data['frame_labels_' + key_type]])

                    raw_texts = list(set(raw_texts))
                    num_texts = len(raw_texts)
                    print(f'num of unique texts_{key_type}: ', len(raw_texts))

                    # get text embeddings by batch
                    text_embeddings = []
                    text_mask = []
                    batch_start_idx = 0
                    while batch_start_idx < num_texts:
                        batch_end_idx = min(batch_start_idx + 256, num_texts)
                        text_embeddings_temp = self.encode_text(raw_texts[batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
                        if text_sep:
                            # encode_text returns (embeddings, mask) in segmented mode.
                            text_embeddings.append(text_embeddings_temp[0])
                            text_mask.append(text_embeddings_temp[1])
                        else:
                            text_embeddings.append(text_embeddings_temp)
                        batch_start_idx = batch_end_idx
                    text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()

                    self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                    if text_sep:
                        self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
                    else:
                        self.text_embedding_dict[key_type][''] = np.zeros(self.dim_embed_text).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
                    with open(embedding_path[key_type], 'wb') as f:
                        pickle.dump(self.text_embedding_dict[key_type], f)
                    if text_sep:
                        text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                        self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                        self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)  # for empty text have zero embedding, compatible with mdm text masking
                        with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'wb') as f:
                            pickle.dump(self.text_mask_dict[key_type], f)

                # Both the loaded and the freshly computed dictionaries hold numpy
                # arrays at this point; move them to the target device as tensors.
                for key in self.text_embedding_dict[key_type]:
                    self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device=self.device)
                    if text_sep:
                        self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device=self.device)

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load the requested CLIP model into ``self.clip_model`` and freeze it.

        Args:
            clip_version: Model name passed to ``clip.load``.
            device: Device the model is loaded onto.
        """
        # jit must stay disabled when the model is used during training.
        model, _preprocess = clip.load(clip_version, device=device, jit=False)
        self.clip_model = model
        # Kept for parity with upstream; per the original author's note CLIP
        # already loads in float16, so this conversion should be a no-op.
        clip.model.convert_weights(model)

        # Inference mode, and no gradients for any CLIP weight.
        model.eval()
        for weight in model.parameters():
            weight.requires_grad_(False)

    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs = 20, sep_mode=0):
        import pandas as pd
        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, 512]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding
                
        raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')
        if sep_mode == 0:
            split_df = raw_series.str.split(r'[,.]', n=max_segs - 1, expand=True)
        elif sep_mode == 1:
            split_df = raw_series.str.split(r'\band\b|\bwhile\b|,|\.', n=max_segs - 1, expand=True)
        split_df = split_df.fillna('').astype(str).applymap(str.strip)

        split_df = split_df.reindex(columns=range(max_segs), fill_value='')
        
        segs_matrix = split_df.values
        segs_flat = segs_matrix.reshape(-1).tolist()

        text_mask = (segs_matrix == '').astype(bool)
        text_mask = torch.tensor(text_mask, dtype=torch.bool, device=device)

        tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
        text_embedding = self.clip_model.encode_text(tokenized).float()      # [B*max_segs, 512]
        text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, 512]

        if force_empty_zero:
            text_embedding[text_mask] = 0

        return text_embedding, text_mask
            
    def get_batch_idx(self, batch_size=8):
        if self.split == 'test':
            start_idx = self._curr_test_index
            end_idx = start_idx + batch_size
            batch_idx = np.arange(start_idx, min(end_idx, len(self.dataset)))
            self._curr_test_index = end_idx if end_idx < len(self.dataset) else 0
            return batch_idx
        else:
            batch_idx = np.random.choice(len(self.dataset), size=batch_size, replace=True, p=self.seq_weights)
            return batch_idx

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        new_text_embeddings = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
        for idx, text in enumerate(new_texts):
            if text_sep:
                self.text_embedding_dict[key_type][text] = new_text_embeddings[0][idx]
                self.text_mask_dict[key_type][text] = new_text_embeddings[1][idx]
            else:
                self.text_embedding_dict[key_type][text] = new_text_embeddings[idx]

    def calc_mean_std(self, batch_size=512):
        """Compute per-feature mean and std over all motion primitives of the dataset.

        Each sequence is cut into overlapping primitives (stride ``future_length``).
        Both persons are canonicalized, converted to features, and pooled into one
        set of statistics. With padding enabled, primitives whose future window
        touches padded frames are excluded.

        Args:
            batch_size: Number of primitives canonicalized at once; capped at 64
                when ``future_length == 1``.

        Returns:
            ``(tensor_mean, tensor_std)``, each ``[1, 1, D]`` on ``self.device``.

        Raises:
            RuntimeError: If no primitive could be extracted from the dataset.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)

        persons = ('person1', 'person2')
        tensor_keys = ('betas', 'transl', 'global_orient', 'body_pose', 'jaw_pose', 'left_eye_pose', 'right_eye_pose',
                       'left_hand_pose', 'right_hand_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints')
        # Features whose last frame is dropped, selected by the total feature size.
        trimmed_features = {
            276: ('transl', 'poses_6d', 'joints'),
            56 * 6: ('transl', 'body_pose'),
            55 * 12: ('joints', 'body_pose'),
        }

        all_mp_data = []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['body_pose'].shape[0]
            # get_primitive treats end_frame as inclusive, hence the strict upper bound.
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]
            if not primitive_data_list:
                # Sequence too short for even one primitive: nothing to contribute.
                continue

            primitive_dict = {}
            for person in persons:
                # Gender is stored as the plain value, as in get_batch (it was
                # previously wrapped in a set literal by mistake).
                person_dict = {'gender': primitive_data_list[0]['primitive_dict'][person]['gender']}
                for key in tensor_keys:
                    person_dict[key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                if self.padding:
                    person_dict['primitive_padding_mask'] = torch.cat([data['primitive_dict'][person]['primitive_padding_mask'] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(person_dict, self.device)

            # split primitive_dict into batches
            num_primitives = len(primitive_dict['person1']['body_pose'])
            batch_start_idx = 0
            while batch_start_idx < num_primitives:
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)
                canonicalized_primitive_dict = {}
                for person in persons:
                    batch_primitive_dict = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict['gender'] = primitive_dict[person]['gender']
                    # deepcopy so canonicalize cannot modify the cached primitives.
                    _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(batch_primitive_dict), use_predicted_joints=True)

                for person in persons:
                    feature_dict = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    for key in trimmed_features.get(self.primitive_utility.feature_dim, ()):
                        feature_dict[key] = feature_dict[key][:, :-1, :]  # [B*num_mp, T, d]
                    motion_tensor = self.dict_to_tensor(feature_dict).detach().cpu()    # [num_primitive, T, D]
                    if self.padding:
                        # Last mask column flags a padded future window; keep only clean primitives.
                        mask_slice = primitive_dict[person]['primitive_padding_mask'][batch_start_idx:batch_end_idx, -1]  # [B]
                        valid_indices = torch.nonzero(~mask_slice, as_tuple=True)[0].detach().cpu()
                        motion_tensor = motion_tensor[valid_indices]
                    all_mp_data.append(motion_tensor)
                batch_start_idx = batch_end_idx

        if not all_mp_data:
            raise RuntimeError('cannot compute mean/std: no motion primitives could be extracted from the dataset')

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        return tensor_mean.to(self.device), tensor_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """end_frame included"""
        primitive_dict = {}
        for person, motion_data in zip(['person1', 'person2'], [seq_data['motion_p1'], seq_data['motion_p2']]):
            primitive_dict[person] = {
                'gender': motion_data['gender'],
                'betas': motion_data['betas'].expand(1, self.primitive_length + 1, 10),
                'transl': motion_data['transl'][start_frame:end_frame + 1].unsqueeze(0),
                'global_orient': motion_data['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                'body_pose': motion_data['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'jaw_pose': motion_data['jaw_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'left_eye_pose': motion_data['left_eye_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'right_eye_pose': motion_data['right_eye_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'left_hand_pose': motion_data['left_hand_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'right_hand_pose': motion_data['right_hand_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'pelvis_delta': motion_data['pelvis_delta'].unsqueeze(0),
                'joints': motion_data['joints'][start_frame:end_frame + 1].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }
            if self.padding:
                padding_mask_full = seq_data[f'motion_p{person[-1]}']['padding_mask'][start_frame:end_frame + 1]  # shape [T+1]
                history_mask = torch.tensor(padding_mask_full[:self.history_length], dtype=torch.bool)
                future_mask = padding_mask_full[self.history_length:-1]
                future_flag = torch.tensor(future_mask.any(), dtype=torch.bool)
                primitive_dict[person]['primitive_padding_mask'] = torch.cat([history_mask, future_flag.unsqueeze(0)], dim=0).unsqueeze(0) # (1, history_length + 1)

        texts = {key: [] for key in self.text_key_list}
        for key_type in self.text_key_list:
            if not skip_text and f'frame_labels_{key_type}' in seq_data:
                future_start = (start_frame + self.history_length) / self.target_fps
                future_end = (start_frame + self.history_length + self.future_length - 1) / self.target_fps
                # print('text tolerance: ', self.text_tolerance)
                for seg in seq_data[f'frame_labels_{key_type}']:
                    if have_overlap([seg['start_t'], seg['end_t']], [future_start - self.text_tolerance, future_end + self.text_tolerance]):
                        texts[key_type].append(seg['proc_label'])

        output = {}
        for key_type in self.text_key_list:
            output['text_'+key_type] = random.choice(texts[key_type]) if len(texts[key_type]) > 0 else ''
        output['primitive_dict'] = primitive_dict
        return output

    def get_batch(self, batch_size=8):
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        add_key_list = ['gender']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'transf_rotmat', 'transf_transl']
        if self.padding:
            cat_key_list.append('primitive_padding_mask')
        if self.use_indi_text:
            add_key_list.append('texts')
        if self.load_text_embedding:
            add_key_list.append('texts')
            cat_key_list.append('text_embedding')
            if self.text_sep:
                cat_key_list.append('text_mask')
        
        for seq_idx in batch_idx:
            seq_data = dict(self.dataset[seq_idx])
            # exchange person1 and person2
            if random.random() < 0.5:
                seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
                if self.use_indi_text:
                    seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = seq_data['frame_labels_person2'], seq_data['frame_labels_person1']
            num_frames = len(seq_data['motion_p1']['transl'])
            if 'text' in self.weight_scheme:
                start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
            else:
                start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
            primitive_data_list = []
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        batch = None
        for gender in ['female', 'male', 'neutral']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = {} if self.mode == 'merged' else []
            
            gender_seq_texts = {key_type: None for key_type in self.text_key_list}
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints', 
                                'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose']:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                    if self.padding:
                        primitive_dict[person]['primitive_padding_mask'] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person]['primitive_padding_mask'] for mp_seq in gender_seq_list], dim=0)
                primitive_texts = {}
                for key_type in self.text_key_list:
                    primitive_texts[key_type] = [mp_seq[primitive_idx]['text_'+key_type] for mp_seq in gender_seq_list]
                    gender_seq_texts[key_type] = primitive_texts[key_type] if gender_seq_texts[key_type] is None else gender_seq_texts[key_type] + primitive_texts[key_type]
                
                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints', 
                                    'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose']:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                        if self.padding:
                            gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'], 
                                                                                           primitive_dict[person]['primitive_padding_mask']], dim=0)

            canonicalized_primitive_dict = {}
            transf_rotmat, transf_transl = {}, {}
            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)            
            if self.mode == 'merged':
                # rel_pose
                rel_rotmat, rel_transl, rel_pose = {}, {}, {}
                rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
                rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
                for rel in ['b2a', 'a2b']:
                    rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                    rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]
                        
            # calc features
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                if self.primitive_utility.feature_dim == 276:
                    feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                    feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                    feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]
                elif self.primitive_utility.feature_dim == 56 * 6:
                    feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                    feature_dict[person]['body_pose'] = feature_dict[person]['body_pose'][:, :-1, :]  # [B*num_mp, T, 55*6]
                elif self.primitive_utility.feature_dim == 55 * 12:
                    feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 55 * 3]
                    feature_dict[person]['body_pose'] = feature_dict[person]['body_pose'][:, :-1, :]  # [B*num_mp, T, 55*6]   
                            
            if self.mode == 'merged':
                for person in ['person1', 'person2']:
                    gender_batch[person] = []
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        if self.use_indi_text:
                            primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                            if len(unseen_texts) > 0:
                                self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                            text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                            if self.text_sep:
                                text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                            else:
                                text_mask = None
               
                        gender_batch[person].append({
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            })
                        if self.use_indi_text:
                            gender_batch[person][-1]['texts'] = primitive_texts
                            if self.load_text_embedding:
                                gender_batch[person][-1]['text_embedding'] = text_embedding
                                gender_batch[person][-1]['text_mask'] = text_mask                            
                        if self.padding:
                            gender_batch[person][-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                gender_batch['interaction'] = []
                for primitive_idx in range(self.num_primitive):        
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                    if self.load_text_embedding:
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                        text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None
                    gender_batch['interaction'].append({
                            'texts': primitive_texts,
                            'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                            'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                        })
                    if self.load_text_embedding:
                        gender_batch['interaction'][-1]['text_embedding'] = text_embedding
                        gender_batch['interaction'][-1]['text_mask'] = text_mask
                   
            elif self.mode == 'sep':
                for person in ['person1', 'person2']:
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)                   # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        if self.use_indi_text:
                            primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                                    # new_text_embeddings = encode_text(self.clip_model, unseen_texts)
                                    # for idx, text in enumerate(unseen_texts):
                                    #     self.text_embedding_dict[person][text] = new_text_embeddings[idx]
                                text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None
                        else:
                            primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                                text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None

                        gender_batch.append(
                            {
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            }
                        )
                        if self.use_indi_text:
                            gender_batch[-1]['texts'] = primitive_texts
                        if self.load_text_embedding:
                            gender_batch[-1]['texts'] = primitive_texts
                            gender_batch[-1]['text_embedding'] = text_embedding
                            gender_batch[-1]['text_mask'] = text_mask
                        if self.padding:
                            gender_batch[-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]
                
                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key], back_group[key] = [], []
                    for d in gender_batch[:self.num_primitive]:
                        front_group[key] += d[key]
                    for d in gender_batch[self.num_primitive:]:
                        back_group[key] += d[key]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                    back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]  
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]  

                selected_batch = []
                for i in range(self.num_primitive):
                    selected_dict = {'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,}
                    for key in front_group.keys():    
                        if key in add_key_list:
                            selected_front = [front_group[key][i] for i in front_indices + i * gender_batch_size] 
                            selected_back = [back_group[key][i] for i in back_indices + i * gender_batch_size]
                            selected_dict[key] = selected_front + selected_back
                        elif key in cat_key_list:
                            selected_front = front_group[key][front_indices + i * gender_batch_size] 
                            selected_back = back_group[key][back_indices + i * gender_batch_size]
                            selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)  
                    selected_batch.append(selected_dict)
                gender_batch = selected_batch
                            
            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    if self.mode == 'merged':
                        for key_type in self.key_list:
                            if key_type != 'interaction':
                                for key in add_key_list:
                                    batch[key_type][primitive_idx][key] = batch[key_type][primitive_idx][key] + gender_batch[key_type][primitive_idx][key]
                                for key in cat_key_list:
                                    batch[key_type][primitive_idx][key] = torch.cat([batch[key_type][primitive_idx][key], gender_batch[key_type][primitive_idx][key]], dim=0)
                            else:
                                for key in ['texts']:
                                    batch[key_type][primitive_idx][key] = batch[key_type][primitive_idx][key] + gender_batch[key_type][primitive_idx][key]
                                for key in ['rel_pose_b2a', 'rel_pose_a2b']:
                                    batch[key_type][primitive_idx][key] = torch.cat([batch[key_type][primitive_idx][key], gender_batch[key_type][primitive_idx][key]], dim=0)
                                if self.load_text_embedding:
                                    batch[key_type][primitive_idx]['text_embedding'] = torch.cat([batch[key_type][primitive_idx]['text_embedding'], gender_batch[key_type][primitive_idx]['text_embedding']], dim=0)
                                    batch[key_type][primitive_idx]['text_mask'] = torch.cat([batch[key_type][primitive_idx]['text_mask'], gender_batch[key_type][primitive_idx]['text_mask']], dim=0)   

                    else:
                        for key in add_key_list:
                            batch[primitive_idx][key] = batch[primitive_idx][key] + gender_batch[primitive_idx][key]
                        for key in cat_key_list:
                            batch[primitive_idx][key] = torch.cat([batch[primitive_idx][key], gender_batch[primitive_idx][key]], dim=0)
            # if self.mode == 'merged':
            #     if random.random() < 0.5:
            #         batch['person1'], batch['person2'] = batch['person2'], batch['person1']

        return batch
    
    def get_item(self, idx):
        """Sample one window of consecutive motion primitives from sequence ``idx``.

        A start frame is drawn at random (weighted by ``seq_data['frame_weights']``
        when ``'text'`` is part of ``self.weight_scheme``, uniformly otherwise),
        ``self.num_primitive`` primitives are cut from the window, stacked along
        dim 0 per person, canonicalized, converted to features and normalized.

        Args:
            idx: Index into ``self.dataset``.

        Returns:
            * ``self.mode == 'merged'``: a dict with keys ``'person1'``,
              ``'person2'`` and ``'interaction'``; each value is a list with one
              dict per primitive. Person entries carry the motion tensors and the
              canonicalization transform; interaction entries carry the texts
              and the relative poses between both persons.
            * any other mode: a list with one dict per primitive, belonging to
              person1 or person2 (chosen with probability 0.5 each).

            Person entries only contain ``'texts'`` (and, if
            ``self.load_text_embedding``, ``'text_embedding'`` / ``'text_mask'``)
            when ``self.use_indi_text`` is set, and ``'primitive_padding_mask'``
            only when ``self.padding`` is set.
        """
        persons = ['person1', 'person2']
        # Per-primitive entries that get stacked along dim 0 over the window.
        seq_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints',
                    'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose']

        seq_data = self.dataset[idx]
        num_frames = len(seq_data['motion_p1']['transl'])
        gender = {
            'person1': seq_data['motion_p1']['gender'],
            'person2': seq_data['motion_p2']['gender'],
        }
        if 'text' in self.weight_scheme:
            # One weight per admissible start frame is expected here.
            start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
        else:
            start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
        # Consecutive primitives start ``future_length`` frames apart.
        primitive_data_list = [
            self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length)
        ]

        # Stack the primitives of the window and collect their texts.
        gender_seq_texts = {key_type: [] for key_type in self.key_list}
        gender_seq_dict = None
        for primitive_idx in range(self.num_primitive):
            primitive_data = primitive_data_list[primitive_idx]
            primitive_dict = {}
            for person in persons:
                source = primitive_data['primitive_dict'][person]
                primitive_dict[person] = {'gender': gender[person]}
                for key in seq_keys:
                    primitive_dict[person][key] = source[key]
                if self.padding:
                    primitive_dict[person]['primitive_padding_mask'] = source['primitive_padding_mask']
            for key_type in self.text_key_list:
                gender_seq_texts[key_type].append(primitive_data['text_' + key_type])

            if gender_seq_dict is None:
                gender_seq_dict = primitive_dict
            else:
                for person in persons:
                    for key in seq_keys:
                        gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                    if self.padding:
                        gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'],
                                                                                       primitive_dict[person]['primitive_padding_mask']], dim=0)

        # Canonicalize each person; the transforms are only kept in merged mode.
        merged = self.mode == 'merged'
        canonicalized_primitive_dict = {}
        transf_rotmat, transf_transl = {}, {}
        for person in persons:
            gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            rotmat, transl, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(
                copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)
            if merged:
                transf_rotmat[person], transf_transl[person] = rotmat, transl

        if merged:
            # Relative pose of each person expressed in the other one's frame:
            # 6D rotation concatenated with the translation.
            rel_rotmat, rel_transl, rel_pose = {}, {}, {}
            rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(
                transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(
                transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
            for rel in ['b2a', 'a2b']:
                rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)

        # calc features
        # Feature entries that lose their last step along dim 1, keyed by the
        # configured feature dimension (presumably so they line up with the
        # delta features — TODO confirm against calc_features).
        trimmed_keys_by_dim = {
            276: ('transl', 'poses_6d', 'joints'),
            56 * 6: ('transl', 'body_pose'),
            55 * 12: ('joints', 'body_pose'),
        }
        trimmed_keys = trimmed_keys_by_dim.get(self.primitive_utility.feature_dim, ())
        feature_dict = {}
        for person in persons:
            feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
            for key in trimmed_keys:
                feature_dict[person][key] = feature_dict[person][key][:, :-1, :]

        def motion_tensors(person):
            # Normalized motion plus the history-only copy and its boolean mask.
            normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))
            # Swap the last two dims and insert a singleton dim at index 2.
            normalized = normalized.permute(0, 2, 1).unsqueeze(2)
            history_mask = torch.zeros_like(normalized, dtype=torch.bool, device=self.device)
            history_mask[..., :self.cfg.history_length] = True
            history_motion = torch.zeros_like(normalized, dtype=torch.float32, device=self.device)
            history_motion[..., :self.cfg.history_length] = normalized[..., :self.cfg.history_length]
            return normalized, history_motion, history_mask

        def embed_texts(texts, key_type):
            # Look up cached embeddings (and masks when ``text_sep``), encoding
            # unseen texts first. Always goes through update_text_embedding_dict
            # rather than writing raw encodings into the embedding cache, which
            # would leave text_mask_dict without entries for those texts.
            unseen_texts = [text for text in texts if text not in self.text_embedding_dict[key_type]]
            if len(unseen_texts) > 0:
                self.update_text_embedding_dict(unseen_texts, key_type, text_sep=self.text_sep, max_segs=self.max_segs)
            text_embedding = torch.stack([self.text_embedding_dict[key_type][text] for text in texts], dim=0)
            text_mask = None
            if self.text_sep:
                text_mask = torch.stack([self.text_mask_dict[key_type][text] for text in texts], dim=0)
            return text_embedding, text_mask

        def person_entries(person, include_transform):
            # One dict per primitive for ``person``.
            normalized, history_motion, history_mask = motion_tensors(person)
            entries = []
            for primitive_idx in range(self.num_primitive):
                # Each primitive occupies exactly one row of the stacked tensors.
                row = slice(primitive_idx, primitive_idx + 1)
                entry = {
                    'gender': [gender_seq_dict[person]['gender']],
                    'betas': gender_seq_dict[person]['betas'][row, :-1, :10],
                    'motion_tensor_normalized': normalized[row, ...],
                    'history_motion': history_motion[row, ...],
                    'history_mask': history_mask[row, ...],
                }
                if include_transform:
                    entry['transf_rotmat'] = transf_rotmat[person][row, ...]
                    entry['transf_transl'] = transf_transl[person][row, ...]
                entry['history_length'] = self.cfg.history_length
                entry['future_length'] = self.cfg.future_length
                # 'texts' is only attached for per-person texts, so no unrelated
                # value can leak into the entry when use_indi_text is off.
                if self.use_indi_text:
                    texts = gender_seq_texts[person][row]
                    entry['texts'] = texts
                    if self.load_text_embedding:
                        entry['text_embedding'], entry['text_mask'] = embed_texts(texts, person)
                if self.padding:
                    entry['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][row, ...]
                entries.append(entry)
            return entries

        if merged:
            data_batch = {}
            for person in persons:
                data_batch[person] = person_entries(person, include_transform=True)
            data_batch['interaction'] = []
            for primitive_idx in range(self.num_primitive):
                row = slice(primitive_idx, primitive_idx + 1)
                texts = gender_seq_texts['interaction'][row]
                entry = {
                    'texts': texts,
                    'rel_pose_b2a': rel_pose['b2a'][row],
                    'rel_pose_a2b': rel_pose['a2b'][row],
                }
                if self.load_text_embedding:
                    entry['text_embedding'], entry['text_mask'] = embed_texts(texts, 'interaction')
                data_batch['interaction'].append(entry)
            return data_batch

        # Separate mode: person1 entries first, then person2; return one half.
        data_batch = []
        for person in persons:
            data_batch.extend(person_entries(person, include_transform=False))
        if random.random() < 0.5:
            return data_batch[:self.num_primitive]
        return data_batch[self.num_primitive:]


class InterXDatasetWPE(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interx',
                 dataset_path='./data/Inter-X/seq_data_single_interaction_fps30',
                 cfg_path='./config_files/config_hydra/motion_primitive/interx_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplx',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='merged', # 'sep' or 'merged'
                 text_sep = False,
                 max_segs = 20,
                 motion_repr = {
                    'body_pose': 55 * 6,
                    'transl': 3,
                    'transl_delta': 3,
                },
                #  motion_repr = {
                #      'joints': 55 * 3,
                #      'joints_delta': 55 * 3,
                #      'body_pose': 55 * 6,
                # },
                 **kwargs):
        """Load one Inter-X split, then prepare normalization statistics and text embeddings.

        The constructor runs four stages in order:
          1. store options and read the primitive config (fps, history/future length,
             number of primitives) from ``cfg_path``;
          2. if ``load_data``: unpickle ``<dataset_path>/<split>.pkl``, optionally pad
             short sequences, convert both persons' motions to torch tensors and
             assign a sampling weight to every sequence;
          3. load the cached mean/std pickle from ``dataset_path`` or, when it is
             missing, compute it with ``calc_mean_std`` and write the cache;
          4. if ``load_text_embedding``: load CLIP and load-or-compute one text
             embedding dict (plus a mask dict when ``text_sep``) per text key type.

        The base-class constructor is not called here.

        Args:
            dataset_name: stored on the instance.
            dataset_path: directory holding ``<split>.pkl`` and the mean/std and text
                embedding cache files.
            cfg_path: OmegaConf-loadable config providing ``fps``, ``history_length``,
                ``future_length`` and ``num_primitive``.
            split: split name; must be ``'train'`` whenever mean/std has to be computed.
            device: device used for PrimitiveUtility, CLIP and the cached embeddings.
            weight_scheme: sequence weighting; handled when it contains ``'uniform'``
                or ``'length'``.
                NOTE(review): any other value leaves ``data['weight']`` unset and the
                later ``seq_weights`` construction would raise KeyError -- confirm.
            prob_static: stored on the instance.
            enforce_gender: when not None, replaces the gender of both persons.
            enforce_zero_beta: when truthy, betas are replaced by zeros.
            load_data: when False the dataset pickle is not read and ``self.dataset`` /
                ``self.seq_weights`` are not set.
            text_tolerance: stored on the instance.
            body_type: forwarded to ``PrimitiveUtility``.
            seed_only: when truthy, asserts that the config has ``num_primitive == 1``.
            use_frame_weights: accepted but not referenced in this constructor.
            mode: ``'merged'`` or ``'sep'``; selects ``self.key_list``.
            text_sep: when truthy, texts are encoded as up to ``max_segs`` segments and
                a separate mask dict is cached next to the embedding dict.
            max_segs: number of segment slots per text when ``text_sep``.
            motion_repr: forwarded to ``PrimitiveUtility``.
                NOTE(review): the default is a dict literal shared across calls.
            **kwargs: reads ``sep_mode`` (default 0), ``padding`` (default False),
                ``clip_version`` (default 'ViT-B/32'), ``load_text_embedding``
                (default False) and ``use_indi_text`` (default False).
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        # sep_mode is forwarded to encode_text and is part of the text_sep cache file name
        self.sep_mode = kwargs.get('sep_mode', 0)
        # padding=True keeps sequences shorter than seq_length and pads them instead of dropping them
        self.padding = kwargs.get('padding', False)
        self.key_list = ['person1', 'person2', 'interaction'] if self.mode=='merged' else ['person1', 'person2']
        self.feet_thre = 0.001
        self.n_joints = 55

        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', False)
        self.use_indi_text = kwargs.get('use_indi_text', False)
        # with individual texts each person has its own label track in addition to the interaction one
        self.text_key_list = ['person1', 'person2', 'interaction'] if self.use_indi_text else ['interaction']

        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        # cfg_path = Path(dataset_path, 'config.yaml')
        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        # self.downsample_rate = 120 // self.target_fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # frames needed to cut num_primitive consecutive primitives (each slice includes one extra end frame)
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1

        if load_data:
            # NOTE(review): pickle.load executes arbitrary code; only use trusted dataset files
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            if not self.padding:
                # without padding, sequences shorter than seq_length are dropped
                dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]
            for data in tqdm(dataset, desc='Processing dataset'):
                if self.padding:
                    T = data['motion_p1']['trans'].shape[0]
                    if T < self.seq_length:
                        # extend both persons to seq_length by repeating their last frame
                        pad_len = self.seq_length - T
                        for person in ['motion_p1', 'motion_p2']:
                            for key in ['trans', 'poses', 'joints']:
                                last_frame = data[person][key][-1:]
                                padding = np.repeat(last_frame, pad_len, axis=0)
                                data[person][key] = np.concatenate([
                                    data[person][key],
                                    padding
                                ], axis=0)

                            # padding_mask: False for the T original frames, True for the appended ones
                            data[person]['padding_mask'] = np.concatenate([
                                np.zeros(T, dtype=np.bool_),
                                np.ones(pad_len, dtype=np.bool_)
                            ], axis=0)
                    else:
                        # long enough already: nothing is padded, mask is all False
                        for person in ['motion_p1', 'motion_p2']:
                            data[person]['padding_mask'] = np.zeros(T, dtype=np.bool_)

                def convert_motion(motion, gender, enforce_zero_beta):
                    """Convert one person's numpy motion dict into a dict of torch tensors.

                    Axis-angle ``poses`` are split by column range into global orientation,
                    body, jaw, eyes and hands and converted to rotation matrices; ``trans``
                    is renamed to ``transl``. ``padding_mask`` is passed through unchanged
                    when padding is enabled.
                    """
                    betas = torch.from_numpy(motion['betas'].astype(np.float32))
                    if enforce_zero_beta:
                        betas = torch.zeros_like(betas)
                    transl = torch.from_numpy(motion['trans'].astype(np.float32))
                    poses = torch.from_numpy(motion['poses'].astype(np.float32))
                    # column layout of poses: 0:3 global orient, 3:66 body (21 joints), 66:69 jaw,
                    # 69:72 left eye, 72:75 right eye, 75:120 left hand (15 joints), 120: right hand (15 joints)
                    global_orient, body_pose, jaw_pose, leye_pose, reye_pose, pose_lhand, pose_rhand = \
                        transforms.axis_angle_to_matrix(poses[:, :3]), \
                        transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3)), \
                        transforms.axis_angle_to_matrix(poses[:, 66:69]), \
                        transforms.axis_angle_to_matrix(poses[:, 69:72]), \
                        transforms.axis_angle_to_matrix(poses[:, 72:75]), \
                        transforms.axis_angle_to_matrix(poses[:, 75:120].reshape(-1, 15, 3)), \
                        transforms.axis_angle_to_matrix(poses[:, 120:].reshape(-1, 15, 3))
                    pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))              # [3]
                    joints = torch.from_numpy(motion['joints'].astype(np.float32))                          # [T, 55, 3]
                    result = {
                        'gender': gender,
                        'betas': betas,
                        'transl': transl,
                        'global_orient': global_orient,
                        'body_pose': body_pose,
                        'jaw_pose': jaw_pose,
                        'left_eye_pose': leye_pose,
                        'right_eye_pose': reye_pose,
                        'left_hand_pose': pose_lhand,
                        'right_hand_pose': pose_rhand,
                        'pelvis_delta': pelvis_delta,
                        'joints': joints,
                    }
                    if self.padding:
                        # kept as the numpy bool array built above
                        result['padding_mask'] = motion['padding_mask']
                    return result
                # enforce_gender, when given, overrides the gender stored with each person
                gender_p1 = self.enforce_gender if self.enforce_gender is not None else data['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else data['motion_p2']['gender']
                # the raw motion dicts are replaced in place by their tensor versions
                data['motion_p1'] = convert_motion(data['motion_p1'], gender_p1, self.enforce_zero_beta)
                data['motion_p2'] = convert_motion(data['motion_p2'], gender_p2, self.enforce_zero_beta)

            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    if self.padding and len(data['motion_p1']['transl'])==self.seq_length:
                        # sequences exactly seq_length long may contain padding: count only unpadded frames
                        data['weight'] = len(data['motion_p1']['transl'])-sum(data['motion_p1']['padding_mask'])
                    else:
                        data['weight'] = len(data['motion_p1']['transl'])
            print('finish first assigning seq weights')

            # overfit using one sequence
            # if 'overfit' in weight_scheme:
            #     seq_id = int(weight_scheme.split('overfit:')[-1].split('_')[0])
            #     for idx, data in enumerate(dataset):
            #         if idx == seq_id:
            #             data['weight'] = 1.0
            #         else:
            #             data['weight'] = 0.0

            # normalize the weights into a probability vector used by get_batch_idx
            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights
        # cursor used by get_batch_idx for sequential iteration over the test split
        self._curr_test_index = 0

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'

        # padded and unpadded statistics are cached in separate files
        suffix = '_padding' if self.padding else ''
        mean_std_path = Path(dataset_path, f'{file_name}{suffix}.pkl')

        if mean_std_path.exists():
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        else:
            # statistics may only be computed from the train split
            assert self.split == 'train'
            print('calculating mean and std using train split')
            result = self.calc_mean_std()
            self.tensor_mean, self.tensor_std = result

            # the cache stores CPU copies; the in-memory tensors stay as returned by calc_mean_std
            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)


        # load clip model, get train text embeddings
        if self.load_text_embedding:
            self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
            self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
            # non-default CLIP versions get their own cache files (the '/' is dropped from the version name)
            suffix = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
            self.embedding_path = {}
            embedding_path = {}
            for key_type in self.text_key_list:
                if text_sep:
                    self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict_textsep_sepmode{self.sep_mode}{suffix}.pkl')
                else:
                    self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{suffix}.pkl')
            self.text_embedding_dict = {}
            if text_sep:
                self.text_mask_dict = {}

            for key_type in self.text_key_list:
                if embedding_path[key_type].exists():
                    print(f"Loading text_{key_type} embeddings from {embedding_path[key_type]}!")
                    with open(embedding_path[key_type], 'rb') as f:
                        self.text_embedding_dict[key_type] = pickle.load(f)
                    if text_sep:
                        # the mask cache shares the embedding cache's file name with 'text_embedding_dict' swapped out
                        with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'rb') as f:
                            self.text_mask_dict[key_type] = pickle.load(f)
                else:
                    print('Calculating text embeddings')
                    # collect every processed label of this key type across the loaded dataset
                    raw_texts = []
                    for data in self.dataset:
                        if f'frame_labels_{key_type}' in data:
                            raw_texts.extend([seg['proc_label'] for seg in data['frame_labels_' + key_type]])

                    raw_texts = list(set(raw_texts))
                    num_texts = len(raw_texts)
                    print(f'num of unique texts_{key_type}: ', len(raw_texts))

                    # get text embeddings by batch
                    text_embeddings = []
                    text_mask = []
                    batch_start_idx = 0
                    while batch_start_idx < num_texts:
                        batch_end_idx = min(batch_start_idx + 256, num_texts)
                        text_embeddings_temp = self.encode_text(raw_texts[batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
                        if text_sep:
                            # encode_text returns (embeddings, mask) in segmented mode
                            text_embeddings.append(text_embeddings_temp[0])
                            text_mask.append(text_embeddings_temp[1])
                        else:
                            text_embeddings.append(text_embeddings_temp)
                        batch_start_idx = batch_end_idx
                    text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()

                    self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                    # the empty string always maps to an all-zero embedding
                    if text_sep:
                        self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
                    else:
                        self.text_embedding_dict[key_type][''] = np.zeros(self.dim_embed_text).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
                    with open(embedding_path[key_type], 'wb') as f:
                        pickle.dump(self.text_embedding_dict[key_type], f)
                    if text_sep:
                        text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                        self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                        # NOTE(review): encode_text marks empty segments as True, whereas the entry
                        # stored for '' is all False -- confirm this difference is intended
                        self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)  # mask entry paired with the zero embedding of the empty text
                        with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'wb') as f:
                            pickle.dump(self.text_mask_dict[key_type], f)

                # cached values are numpy arrays; convert them to tensors on self.device for lookup at batch time
                for key in self.text_embedding_dict[key_type]:
                    self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device=self.device)
                    if text_sep:
                        self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device=self.device)

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load the CLIP model `clip_version` onto `device`, store it and freeze it.

        The model is kept on ``self.clip_model``, put into eval mode, and every
        parameter has ``requires_grad`` switched off.
        """
        # jit must stay disabled for training; the second value returned by clip.load is not used
        model, _unused = clip.load(clip_version, device=device, jit=False)
        self.clip_model = model

        # The original author notes this conversion is likely redundant because
        # CLIP already loads in float16 by default.
        clip.model.convert_weights(model)

        # Freeze CLIP weights
        model.eval()
        for weight in model.parameters():
            weight.requires_grad = False

    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs = 20, sep_mode=0):
        import pandas as pd
        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, 512]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding
                
        raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')
        if sep_mode == 0:
            split_df = raw_series.str.split(r'[,.]', n=max_segs - 1, expand=True)
        elif sep_mode == 1:
            split_df = raw_series.str.split(r'\band\b|\bwhile\b|,|\.', n=max_segs - 1, expand=True)
        split_df = split_df.fillna('').astype(str).applymap(str.strip)

        split_df = split_df.reindex(columns=range(max_segs), fill_value='')
        
        segs_matrix = split_df.values
        segs_flat = segs_matrix.reshape(-1).tolist()

        text_mask = (segs_matrix == '').astype(bool)
        text_mask = torch.tensor(text_mask, dtype=torch.bool, device=device)

        tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
        text_embedding = self.clip_model.encode_text(tokenized).float()      # [B*max_segs, 512]
        text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, 512]

        if force_empty_zero:
            text_embedding[text_mask] = 0

        return text_embedding, text_mask
            
    def get_batch_idx(self, batch_size=8):
        if self.split == 'test':
            start_idx = self._curr_test_index
            end_idx = start_idx + batch_size
            batch_idx = np.arange(start_idx, min(end_idx, len(self.dataset)))
            self._curr_test_index = end_idx if end_idx < len(self.dataset) else 0
            return batch_idx
        else:
            batch_idx = np.random.choice(len(self.dataset), size=batch_size, replace=True, p=self.seq_weights)
            return batch_idx

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        """Encode `new_texts` and add them to the cached dicts of `key_type`.

        Every text gets an entry in ``self.text_embedding_dict[key_type]``; in
        segmented mode (``text_sep``) its mask is also stored in
        ``self.text_mask_dict[key_type]``. Existing entries for the same text are
        overwritten.
        """
        encoded = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
        if not text_sep:
            # plain mode: encode_text returned one embedding row per text
            for row, text in enumerate(new_texts):
                self.text_embedding_dict[key_type][text] = encoded[row]
            return

        # segmented mode: encode_text returned (embeddings, masks)
        for row, text in enumerate(new_texts):
            self.text_embedding_dict[key_type][text] = encoded[0][row]
            self.text_mask_dict[key_type][text] = encoded[1][row]

    def calc_mean_std(self, batch_size=512):
        """Compute per-feature mean and std over all primitives of the loaded dataset.

        Every sequence is cut into primitives (stride ``future_length``); each
        person's primitives are canonicalized and turned into feature tensors in
        batches, and the statistics are taken over both persons, all primitives
        and all frames. With padding enabled, primitives whose future part
        contains padded frames are excluded.

        Args:
            batch_size: number of primitives processed per canonicalization
                batch; capped at 64 when ``future_length`` is 1.

        Returns:
            Tuple ``(tensor_mean, tensor_std)``, each of shape ``[1, 1, D]`` and
            placed on ``self.device``.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)

        persons = ['person1', 'person2']
        tensor_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'jaw_pose', 'left_eye_pose', 'right_eye_pose',
                       'left_hand_pose', 'right_hand_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']
        # features whose trailing extra frame is dropped, keyed by the utility's feature dimension
        trimmed_features = {
            276: ('transl', 'poses_6d', 'joints'),
            56 * 6: ('transl', 'body_pose'),
            55 * 12: ('joints', 'body_pose'),
        }

        all_mp_data = []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['body_pose'].shape[0]
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]
            if not primitive_data_list:
                # sequence too short to yield a single primitive: it contributes nothing
                continue

            primitive_dict = {}
            for person in persons:
                # gender is stored as the plain value, exactly as get_batch passes it to
                # canonicalize (the previous `{...}` wrapped it in a one-element set)
                person_dict = {'gender': primitive_data_list[0]['primitive_dict'][person]['gender']}
                for key in tensor_keys:
                    person_dict[key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                if self.padding:
                    person_dict['primitive_padding_mask'] = torch.cat([data['primitive_dict'][person]['primitive_padding_mask'] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(person_dict, self.device)

            # split primitive_dict into batches
            num_primitives = len(primitive_dict['person1']['body_pose'])
            for batch_start_idx in range(0, num_primitives, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)

                canonicalized_primitive_dict = {}
                for person in persons:
                    batch_primitive_dict = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict['gender'] = primitive_dict[person]['gender']
                    # canonicalize works in place, so hand it a copy
                    _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(batch_primitive_dict), use_predicted_joints=True)

                for person in persons:
                    features = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    for key in trimmed_features.get(self.primitive_utility.feature_dim, ()):
                        features[key] = features[key][:, :-1, :]  # [B*num_mp, T, ...]
                    motion_tensor = self.dict_to_tensor(features).detach().cpu()    # [num_primitive, T, D]
                    if self.padding:
                        # last mask column flags primitives whose future frames include padding
                        mask_slice = primitive_dict[person]['primitive_padding_mask'][batch_start_idx:batch_end_idx, -1]  # [B]
                        valid_indices = torch.nonzero(~mask_slice, as_tuple=True)[0].detach().cpu()  # select valid
                        motion_tensor = motion_tensor[valid_indices]
                    all_mp_data.append(motion_tensor)

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        return tensor_mean.to(self.device), tensor_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """end_frame included"""
        primitive_dict = {}
        for person, motion_data in zip(['person1', 'person2'], [seq_data['motion_p1'], seq_data['motion_p2']]):
            primitive_dict[person] = {
                'gender': motion_data['gender'],
                'betas': motion_data['betas'].expand(1, self.primitive_length + 1, 10),
                'transl': motion_data['transl'][start_frame:end_frame + 1].unsqueeze(0),
                'global_orient': motion_data['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                'body_pose': motion_data['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'jaw_pose': motion_data['jaw_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'left_eye_pose': motion_data['left_eye_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'right_eye_pose': motion_data['right_eye_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'left_hand_pose': motion_data['left_hand_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'right_hand_pose': motion_data['right_hand_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'pelvis_delta': motion_data['pelvis_delta'].unsqueeze(0),
                'joints': motion_data['joints'][start_frame:end_frame + 1].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }
            if self.padding:
                padding_mask_full = seq_data[f'motion_p{person[-1]}']['padding_mask'][start_frame:end_frame + 1]  # shape [T+1]
                history_mask = torch.tensor(padding_mask_full[:self.history_length], dtype=torch.bool)
                future_mask = padding_mask_full[self.history_length:-1]
                future_flag = torch.tensor(future_mask.any(), dtype=torch.bool)
                primitive_dict[person]['primitive_padding_mask'] = torch.cat([history_mask, future_flag.unsqueeze(0)], dim=0).unsqueeze(0) # (1, history_length + 1)

        texts = {key: [] for key in self.text_key_list}
        for key_type in self.text_key_list:
            if not skip_text and f'frame_labels_{key_type}' in seq_data:
                future_start = (start_frame + self.history_length) / self.target_fps
                future_end = (start_frame + self.history_length + self.future_length - 1) / self.target_fps
                # print('text tolerance: ', self.text_tolerance)
                for seg in seq_data[f'frame_labels_{key_type}']:
                    if have_overlap([seg['start_t'], seg['end_t']], [future_start - self.text_tolerance, future_end + self.text_tolerance]):
                        texts[key_type].append(seg['proc_label'])

        output = {}
        for key_type in self.text_key_list:
            output['text_'+key_type] = random.choice(texts[key_type]) if len(texts[key_type]) > 0 else ''
        output['primitive_dict'] = primitive_dict
        return output

    def get_batch(self, batch_size=8):
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        add_key_list = ['gender']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'transf_rotmat', 'transf_transl', 'start_frame', 'total_frames']
        if self.padding:
            cat_key_list.append('primitive_padding_mask')
        if self.use_indi_text:
            add_key_list.append('texts')
        if self.load_text_embedding:
            cat_key_list.append('text_embedding')
            if self.text_sep:
                cat_key_list.append('text_mask')
        
        for seq_idx in batch_idx:
            seq_data = dict(self.dataset[seq_idx])
            # exchange person1 and person2
            if random.random() < 0.5:
                seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
                if self.use_indi_text:
                    seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = seq_data['frame_labels_person2'], seq_data['frame_labels_person1']
            num_frames = len(seq_data['motion_p1']['transl'])
            if 'text' in self.weight_scheme:
                start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
            else:
                start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
            primitive_data_list = []
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                for person in ['person1', 'person2']:
                    primitive_data['primitive_dict'][person]['start_frame'] = torch.tensor(frame_idx).view(1)
                    primitive_data['primitive_dict'][person]['total_frames'] = torch.tensor(num_frames).view(1)
                primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        batch = None
        for gender in ['female', 'male', 'neutral']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = {} if self.mode == 'merged' else []
            
            gender_seq_texts = {key_type: None for key_type in self.text_key_list}
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints', 
                                'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose', 'start_frame', 'total_frames']:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                    if self.padding:
                        primitive_dict[person]['primitive_padding_mask'] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person]['primitive_padding_mask'] for mp_seq in gender_seq_list], dim=0)
                primitive_texts = {}
                for key_type in self.text_key_list:
                    primitive_texts[key_type] = [mp_seq[primitive_idx]['text_'+key_type] for mp_seq in gender_seq_list]
                    gender_seq_texts[key_type] = primitive_texts[key_type] if gender_seq_texts[key_type] is None else gender_seq_texts[key_type] + primitive_texts[key_type]
                
                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints', 
                                    'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose', 'start_frame', 'total_frames']:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                        if self.padding:
                            gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'], 
                                                                                           primitive_dict[person]['primitive_padding_mask']], dim=0)

            canonicalized_primitive_dict = {}
            transf_rotmat, transf_transl = {}, {}
            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)            
            if self.mode == 'merged':
                # rel_pose
                rel_rotmat, rel_transl, rel_pose = {}, {}, {}
                rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
                rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
                for rel in ['b2a', 'a2b']:
                    rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                    rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]
                        
            # calc features
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                if self.primitive_utility.feature_dim == 276:
                    feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                    feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                    feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]
                elif self.primitive_utility.feature_dim == 56 * 6:
                    feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                    feature_dict[person]['body_pose'] = feature_dict[person]['body_pose'][:, :-1, :]  # [B*num_mp, T, 55*6]
                elif self.primitive_utility.feature_dim == 55 * 12:
                    feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 55 * 3]
                    feature_dict[person]['body_pose'] = feature_dict[person]['body_pose'][:, :-1, :]  # [B*num_mp, T, 55*6]   
                            
            if self.mode == 'merged':
                for person in ['person1', 'person2']:
                    gender_batch[person] = []
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        if self.use_indi_text:
                            primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                            if len(unseen_texts) > 0:
                                self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                            text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                            if self.text_sep:
                                text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                            else:
                                text_mask = None
               
                        gender_batch[person].append({
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'start_frame': gender_seq_dict[person]['start_frame'][start_idx:end_idx],
                                'total_frames': gender_seq_dict[person]['total_frames'][start_idx:end_idx],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            })
                        if self.use_indi_text:
                            gender_batch[person][-1]['texts'] = primitive_texts
                            if self.load_text_embedding:
                                gender_batch[person][-1]['text_embedding'] = text_embedding
                                gender_batch[person][-1]['text_mask'] = text_mask                            
                        if self.padding:
                            gender_batch[person][-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                gender_batch['interaction'] = []
                for primitive_idx in range(self.num_primitive):        
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                    if self.load_text_embedding:
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                        text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None
                    gender_batch['interaction'].append({
                            'texts': primitive_texts,
                            'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                            'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                        })
                    if self.load_text_embedding:
                        gender_batch['interaction'][-1]['text_embedding'] = text_embedding
                        gender_batch['interaction'][-1]['text_mask'] = text_mask
                   
            elif self.mode == 'sep':
                for person in ['person1', 'person2']:
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)                   # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        if self.use_indi_text:
                            primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                                    # new_text_embeddings = encode_text(self.clip_model, unseen_texts)
                                    # for idx, text in enumerate(unseen_texts):
                                    #     self.text_embedding_dict[person][text] = new_text_embeddings[idx]
                                text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None
                        else:
                            primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                                text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None

                        gender_batch.append(
                            {
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            }
                        )
                        if self.use_indi_text:
                            gender_batch[-1]['texts'] = primitive_texts
                        if self.load_text_embedding:
                            gender_batch[-1]['text_embedding'] = text_embedding
                            gender_batch[-1]['text_mask'] = text_mask
                        if self.padding:
                            gender_batch[-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]
                
                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key], back_group[key] = [], []
                    for d in gender_batch[:self.num_primitive]:
                        front_group[key] += d[key]
                    for d in gender_batch[self.num_primitive:]:
                        back_group[key] += d[key]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                    back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]  
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]  

                selected_batch = []
                for i in range(self.num_primitive):
                    selected_dict = {'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,}
                    for key in front_group.keys():    
                        if key in add_key_list:
                            selected_front = [front_group[key][i] for i in front_indices + i * gender_batch_size] 
                            selected_back = [back_group[key][i] for i in back_indices + i * gender_batch_size]
                            selected_dict[key] = selected_front + selected_back
                        elif key in cat_key_list:
                            selected_front = front_group[key][front_indices + i * gender_batch_size] 
                            selected_back = back_group[key][back_indices + i * gender_batch_size]
                            selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)  
                    selected_batch.append(selected_dict)
                gender_batch = selected_batch
                            
            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    if self.mode == 'merged':
                        for key_type in self.key_list:
                            if key_type != 'interaction':
                                for key in add_key_list:
                                    batch[key_type][primitive_idx][key] = batch[key_type][primitive_idx][key] + gender_batch[key_type][primitive_idx][key]
                                for key in cat_key_list:
                                    batch[key_type][primitive_idx][key] = torch.cat([batch[key_type][primitive_idx][key], gender_batch[key_type][primitive_idx][key]], dim=0)
                            else:
                                for key in ['texts']:
                                    batch[key_type][primitive_idx][key] = batch[key_type][primitive_idx][key] + gender_batch[key_type][primitive_idx][key]
                                for key in ['rel_pose_b2a', 'rel_pose_a2b']:
                                    batch[key_type][primitive_idx][key] = torch.cat([batch[key_type][primitive_idx][key], gender_batch[key_type][primitive_idx][key]], dim=0)
                                if self.load_text_embedding:
                                    batch[key_type][primitive_idx]['text_embedding'] = torch.cat([batch[key_type][primitive_idx]['text_embedding'], gender_batch[key_type][primitive_idx]['text_embedding']], dim=0)
                                    batch[key_type][primitive_idx]['text_mask'] = torch.cat([batch[key_type][primitive_idx]['text_mask'], gender_batch[key_type][primitive_idx]['text_mask']], dim=0)   

                    else:
                        for key in add_key_list:
                            batch[primitive_idx][key] = batch[primitive_idx][key] + gender_batch[primitive_idx][key]
                        for key in cat_key_list:
                            batch[primitive_idx][key] = torch.cat([batch[primitive_idx][key], gender_batch[primitive_idx][key]], dim=0)
            # if self.mode == 'merged':
            #     if random.random() < 0.5:
            #         batch['person1'], batch['person2'] = batch['person2'], batch['person1']

        return batch
    
    def get_item(self, idx):
        """Build one sample: a window of consecutive motion primitives from sequence ``idx``.

        A start frame is drawn at random -- weighted by ``seq_data['frame_weights']``
        when ``'text'`` appears in ``self.weight_scheme``, uniformly otherwise.
        Primitives of ``self.primitive_length`` frames are then cut every
        ``self.future_length`` frames, concatenated along dim 0 for each person,
        canonicalized, converted to features and normalized.

        Args:
            idx: Index into ``self.dataset``.

        Returns:
            * ``self.mode == 'merged'``: a dict with keys ``'person1'``,
              ``'person2'`` and ``'interaction'``; each value is a list holding one
              dict per primitive. Person entries carry the motion tensors and the
              canonicalization transform; interaction entries carry the interaction
              texts and the relative poses ``rel_pose_b2a`` / ``rel_pose_a2b``.
            * any other mode: a list with one dict per primitive for a single
              person, chosen between person1 and person2 with equal probability.

            Person entries contain ``'texts'`` only when ``self.use_indi_text`` is
            set, and ``'text_embedding'`` / ``'text_mask'`` only when
            ``self.load_text_embedding`` is set as well.
        """
        persons = ('person1', 'person2')
        # Per-primitive entries that are concatenated along dim 0 over the window.
        stacked_keys = ('betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl',
                        'pelvis_delta', 'joints', 'jaw_pose', 'left_eye_pose', 'right_eye_pose',
                        'left_hand_pose', 'right_hand_pose')
        if self.padding:
            stacked_keys = stacked_keys + ('primitive_padding_mask',)
        merged = self.mode == 'merged'

        seq_data = self.dataset[idx]
        num_frames = len(seq_data['motion_p1']['transl'])
        gender = {
            'person1': seq_data['motion_p1']['gender'],
            'person2': seq_data['motion_p2']['gender'],
        }
        if 'text' in self.weight_scheme:
            start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
        else:
            start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive

        window_stop = start_frame + self.seq_length - self.primitive_length
        primitive_data_list = [
            self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
            for frame_idx in range(start_frame, window_stop, self.future_length)
        ]

        # Gather the per-primitive tensors and texts of the whole window.
        gender_seq_texts = {key_type: [] for key_type in self.key_list}
        gender_seq_dict = None
        for primitive_idx in range(self.num_primitive):
            primitive_data = primitive_data_list[primitive_idx]
            primitive_dict = {}
            for person in persons:
                source = primitive_data['primitive_dict'][person]
                primitive_dict[person] = {'gender': gender[person]}
                for key in stacked_keys:
                    primitive_dict[person][key] = source[key]
            for key_type in self.text_key_list:
                gender_seq_texts[key_type].append(primitive_data['text_' + key_type])

            if gender_seq_dict is None:
                gender_seq_dict = primitive_dict
                continue
            for person in persons:
                for key in stacked_keys:
                    gender_seq_dict[person][key] = torch.cat(
                        [gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)

        # Canonicalize each person. The returned transform is only consumed in
        # merged mode (relative poses and per-entry transforms).
        canonicalized_primitive_dict = {}
        transf_rotmat, transf_transl = {}, {}
        for person in persons:
            gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            (transf_rotmat[person],
             transf_transl[person],
             canonicalized_primitive_dict[person]) = self.primitive_utility.canonicalize(
                copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)

        if merged:
            # Relative pose between the two canonical frames, in both directions.
            rel_rotmat, rel_transl, rel_pose = {}, {}, {}
            rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(
                transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(
                transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
            for rel in ['b2a', 'a2b']:
                rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [num_mp, 6+3]

        # calc features; for the known feature layouts the listed entries lose
        # their last element along dim 1.
        trimmed_keys_by_dim = {
            276: ('transl', 'poses_6d', 'joints'),
            56 * 6: ('transl', 'body_pose'),
            55 * 12: ('joints', 'body_pose'),
        }
        feature_dict = {}
        for person in persons:
            feature_dict[person] = self.primitive_utility.calc_features(
                canonicalized_primitive_dict[person], use_predicted_joints=True)
            for key in trimmed_keys_by_dim.get(self.primitive_utility.feature_dim, ()):
                feature_dict[person][key] = feature_dict[person][key][:, :-1, :]

        if merged:
            data_batch = {}
            for person in persons:
                data_batch[person] = self._get_item_person_entries(
                    person, feature_dict, gender_seq_dict, gender_seq_texts,
                    transf_rotmat=transf_rotmat[person], transf_transl=transf_transl[person])
            data_batch['interaction'] = []
            for primitive_idx in range(self.num_primitive):
                window = slice(primitive_idx, primitive_idx + 1)
                primitive_texts = gender_seq_texts['interaction'][window]
                entry = {
                    'texts': primitive_texts,
                    'rel_pose_b2a': rel_pose['b2a'][window],
                    'rel_pose_a2b': rel_pose['a2b'][window],
                }
                if self.load_text_embedding:
                    entry['text_embedding'], entry['text_mask'] = self._get_item_text_condition(
                        primitive_texts, 'interaction')
                data_batch['interaction'].append(entry)
            return data_batch

        # Separate mode: build both persons, then keep one of them at random.
        # NOTE(review): without use_indi_text no text conditioning is attached
        # here -- confirm this is intended for callers of get_item in this mode.
        data_batch = []
        for person in persons:
            data_batch.extend(self._get_item_person_entries(
                person, feature_dict, gender_seq_dict, gender_seq_texts))
        if random.random() < 0.5:
            return data_batch[:self.num_primitive]
        return data_batch[self.num_primitive:]

    def _get_item_text_condition(self, primitive_texts, key_type):
        """Return ``(text_embedding, text_mask)`` for ``primitive_texts`` cached under ``key_type``.

        Texts missing from ``self.text_embedding_dict[key_type]`` are first passed
        to ``self.update_text_embedding_dict``. ``text_mask`` is ``None`` unless
        ``self.text_sep`` is set.
        """
        cache = self.text_embedding_dict[key_type]
        unseen_texts = [text for text in primitive_texts if text not in cache]
        if len(unseen_texts) > 0:
            # Always go through update_text_embedding_dict: writing only the
            # embedding cache leaves self.text_mask_dict untouched, so the mask
            # lookup below could miss newly seen texts when text_sep is enabled.
            self.update_text_embedding_dict(unseen_texts, key_type, text_sep=self.text_sep, max_segs=self.max_segs)
        text_embedding = torch.stack([self.text_embedding_dict[key_type][text] for text in primitive_texts], dim=0)
        if self.text_sep:
            text_mask = torch.stack([self.text_mask_dict[key_type][text] for text in primitive_texts], dim=0)
        else:
            text_mask = None
        return text_embedding, text_mask

    def _get_item_person_entries(self, person, feature_dict, gender_seq_dict, gender_seq_texts,
                                 transf_rotmat=None, transf_transl=None):
        """Build the list of per-primitive sample dicts for one person.

        ``transf_rotmat`` / ``transf_transl`` are added to every entry only when
        both are given (merged mode).
        """
        person_seq = gender_seq_dict[person]
        history_length = self.cfg.history_length

        motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))  # [num_mp, T, D]
        motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)     # [num_mp, D, 1, T]
        # Mask and copy of the leading history frames; everything else stays zero.
        history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
        history_mask[..., :history_length] = True
        history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
        history_motion[..., :history_length] = motion_tensor_normalized[..., :history_length]

        entries = []
        for primitive_idx in range(self.num_primitive):
            window = slice(primitive_idx, primitive_idx + 1)  # keeps a leading dim of size 1
            entry = {
                'gender': [person_seq['gender']],
                'betas': person_seq['betas'][window, :-1, :10],
                'motion_tensor_normalized': motion_tensor_normalized[window, ...],  # [1, D, 1, T]
                'history_motion': history_motion[window, ...],
                'history_mask': history_mask[window, ...],
            }
            if transf_rotmat is not None and transf_transl is not None:
                entry['transf_rotmat'] = transf_rotmat[window, ...]
                entry['transf_transl'] = transf_transl[window, ...]
            entry['history_length'] = self.cfg.history_length
            entry['future_length'] = self.cfg.future_length

            # 'texts' is only meaningful with per-person texts; it is not emitted
            # otherwise, so no unrelated value leaks into the sample.
            if self.use_indi_text:
                primitive_texts = gender_seq_texts[person][window]
                entry['texts'] = primitive_texts
                if self.load_text_embedding:
                    entry['text_embedding'], entry['text_mask'] = self._get_item_text_condition(
                        primitive_texts, person)
            if self.padding:
                entry['primitive_padding_mask'] = person_seq['primitive_padding_mask'][window, ...]
            entries.append(entry)
        return entries


# dataset = InterXDataset(enforce_gender=None,
#                             enforce_zero_beta=0,
#                             device='cuda:7',
#                             mode='merged',
#                             text_sep=True,
#                             split='train',
#                             padding=True,)
# # dataset.calc_mean_std()
# batch = dataset.get_batch(4)
# # dataset.get_item(0)

class InterXDatasetEval(data.Dataset):
    def __init__(self, dataset_name='interx',
                 dataset_path='./data/Inter-X/seq_data_single_interaction_fps30',
                 cfg_path='./config_files/config_hydra/motion_primitive/interx_h2_f8_r4.yaml',
                 split="test",
                 device='cuda',
                 load_data=True,
                 enforce_gender='male',
                 enforce_zero_beta = True,
                 body_type='smplx',
                 text_sep = True,
                 max_segs = 20,
                 min_length=24,
                 max_length=150,
                 max_text_len=35,
                 unit_length=4,
                 motion_repr = {
                    'body_pose': 55 * 6,
                    'transl': 3,
                    'transl_delta': 3,
                 },
                 #  motion_repr = {
                 #      'joints': 55 * 3,
                 #      'joints_delta': 55 * 3,
                 #      'body_pose': 55 * 6,
                 # },
                 padding=False,
                 w_vectorizer=None,
                 **kwargs):
        """Load Inter-X evaluation sequences, captions, normalization stats and text embeddings.

        Args:
            dataset_name: Stored on the instance as-is.
            dataset_path: Directory holding ``{split}.pkl``, the mean/std pickle and the
                cached text-embedding pickles. Its parent directory is expected to hold
                ``inter-x/splits`` and ``inter-x/texts_processed``.
            cfg_path: OmegaConf file providing ``fps``, ``history_length`` and
                ``future_length``.
            split: Split name used to build the file names.
            device: Device handed to ``PrimitiveUtility`` and to the CLIP loader.
            load_data: When true, read and convert the motion sequences and captions.
            enforce_gender: If not ``None``, overrides the gender stored with each person.
            enforce_zero_beta: If truthy, shape parameters are replaced by zeros.
            body_type: Forwarded to ``PrimitiveUtility``.
            text_sep: Whether text embeddings are stored per segment (with a mask).
            max_segs: Number of segments per text when ``text_sep`` is enabled.
            min_length: Sequences shorter than this many frames are skipped.
            max_length: Stored as ``self.max_motion_length``; samples are padded or
                cropped to ``max_length + 1`` frames in ``__getitem__``.
            max_text_len: Maximum number of caption tokens kept (before sos/eos).
            unit_length: Motion lengths are cropped to a multiple of this value.
            motion_repr: Forwarded to ``PrimitiveUtility``.
            padding: Selects the ``_padding`` variant of the mean/std file.
            w_vectorizer: Token -> (word embedding, POS one-hot) lookup used in
                ``__getitem__``.
            **kwargs: Optional ``clip_version`` (default ``'ViT-B/32'``),
                ``load_text_embedding`` (default ``False``) and ``use_indi_text``
                (default ``False``).

        Raises:
            ValueError: If a caption line carries non-zero time tags, which this
                loader does not support.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.min_length = min_length
        self.max_motion_length = max_length
        # Initial length threshold for reset_max_len(); independent of max_length above.
        self.max_length = 20
        self.pointer = 0
        self.max_text_len = max_text_len
        self.unit_length = unit_length
        self.w_vectorizer = w_vectorizer
        self.padding = padding

        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', False)
        self.use_indi_text = kwargs.get('use_indi_text', False)

        self.key_list = ['person1', 'person2', 'interaction']
        self.text_key_list = ['person1', 'person2', 'interaction'] if self.use_indi_text else ['interaction']

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length

        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        if load_data:
            data_root = dataset_path.rsplit('/', 1)[0]
            data_dict = {}
            # NOTE(review): the split ids are read but never used to filter sequences below.
            id_list = []
            split_file = data_root + '/' + 'inter-x/splits' + f'/{split}.txt'
            with cs.open(split_file, 'r') as f:
                for line in f:
                    id_list.append(line.strip())
            new_name_list = []
            length_list = []
            # NOTE: pickle.load executes arbitrary code; only load trusted dataset files.
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            texts_processed_path = data_root + '/' + 'inter-x/texts_processed'

            def convert_motion(motion, gender, enforce_zero_beta):
                """Convert one person's raw numpy motion into torch tensors with rotation matrices."""
                betas = torch.from_numpy(motion['betas'].astype(np.float32))
                if enforce_zero_beta:
                    betas = torch.zeros_like(betas)
                transl = torch.from_numpy(motion['trans'].astype(np.float32))
                poses = torch.from_numpy(motion['poses'].astype(np.float32))
                # Axis-angle layout: 3 global + 21*3 body + 3 jaw + 3 left eye + 3 right eye
                # + 15*3 left hand + 15*3 right hand.
                global_orient = transforms.axis_angle_to_matrix(poses[:, :3])
                body_pose = transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3))
                jaw_pose = transforms.axis_angle_to_matrix(poses[:, 66:69])
                leye_pose = transforms.axis_angle_to_matrix(poses[:, 69:72])
                reye_pose = transforms.axis_angle_to_matrix(poses[:, 72:75])
                pose_lhand = transforms.axis_angle_to_matrix(poses[:, 75:120].reshape(-1, 15, 3))
                pose_rhand = transforms.axis_angle_to_matrix(poses[:, 120:].reshape(-1, 15, 3))
                pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))              # [3]
                joints = torch.from_numpy(motion['joints'].astype(np.float32))                          # [T, 55, 3]
                return {
                    'gender': gender,
                    # Repeat the per-sequence shape vector for every frame.
                    'betas': betas.unsqueeze(0).expand(transl.shape[0], 10),
                    'transl': transl,
                    'global_orient': global_orient,
                    'body_pose': body_pose,
                    'jaw_pose': jaw_pose,
                    'left_eye_pose': leye_pose,
                    'right_eye_pose': reye_pose,
                    'left_hand_pose': pose_lhand,
                    'right_hand_pose': pose_rhand,
                    'pelvis_delta': pelvis_delta,
                    'joints': joints,
                }

            for seq in tqdm(dataset, desc='Processing dataset'):
                T = seq['motion_p1']['trans'].shape[0]
                if T < self.min_length or T > 1000:
                    continue
                gender_p1 = self.enforce_gender if self.enforce_gender is not None else seq['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else seq['motion_p2']['gender']
                seq['motion_p1'] = convert_motion(seq['motion_p1'], gender_p1, self.enforce_zero_beta)
                seq['motion_p2'] = convert_motion(seq['motion_p2'], gender_p2, self.enforce_zero_beta)

                # Each caption line is "caption#tokens#f_tag#to_tag".
                text_data = []
                text_file = pjoin(texts_processed_path, f'{seq["seq_name"]}.txt')
                with cs.open(text_file, 'r') as f:
                    for line in f:
                        line_split = line.strip().split('#')
                        caption = line_split[0]
                        tokens = line_split[1].split(' ')
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag
                        if f_tag != 0.0 or to_tag != 0.0:
                            # Captions restricted to a sub-range of the motion are not handled;
                            # fail loudly instead of silently terminating the process.
                            raise ValueError(
                                f'{text_file}: caption with non-zero time tags '
                                f'(f_tag={f_tag}, to_tag={to_tag}) is not supported'
                            )
                        text_data.append({'caption': caption, 'tokens': tokens})
                # Sequences without any caption are dropped.
                if text_data:
                    seq['text'] = text_data
                    seq['length'] = T
                    data_dict[seq['seq_name']] = seq
                    new_name_list.append(seq['seq_name'])
                    length_list.append(T)
            print('num of sequences: ', len(data_dict))

            # Sort by length so reset_max_len() can binary-search the start pointer.
            if new_name_list:
                name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
            else:
                name_list, length_list = (), ()
            self.length_arr = np.array(length_list)
            self.data_dict = data_dict
            self.name_list = name_list
            self.reset_max_len(self.max_length)

        self.tensor_mean_device_dict = {}
        suffix = '_padding' if padding else ''
        mean_std_path = Path(dataset_path, f'mean_std_h{self.history_length}_f{self.future_length}{suffix}.pkl')
        try:
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        except FileNotFoundError:
            # tensor_mean / tensor_std stay unset; normalize()/denormalize() will fail if used.
            print('Error: mean and std not found!')

        # load clip model, get train text embeddings
        if self.load_text_embedding:
            self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
            self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
            suffix = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
            cache_tag = '_textsep' if text_sep else ''
            self.embedding_path = {
                key_type: Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{cache_tag}{suffix}.pkl')
                for key_type in self.text_key_list
            }
            self.text_embedding_dict = {}
            if text_sep:
                self.text_mask_dict = {}

            for key_type in self.text_key_list:
                emb_path = self.embedding_path[key_type]
                mask_path = Path(str(emb_path).replace('text_embedding_dict', 'text_mask_dict'))
                if emb_path.exists():
                    print(f"Loading text_{key_type} embeddings from {emb_path}!")
                    with open(emb_path, 'rb') as f:
                        self.text_embedding_dict[key_type] = pickle.load(f)
                    if text_sep:
                        with open(mask_path, 'rb') as f:
                            self.text_mask_dict[key_type] = pickle.load(f)
                else:
                    print('Calculating text embeddings')
                    raw_texts = []
                    for seq in self.data_dict.values():
                        if f'frame_labels_{key_type}' in seq:
                            raw_texts.extend([seg['proc_label'].strip() for seg in seq['frame_labels_' + key_type]])

                    raw_texts = list(set(raw_texts))
                    num_texts = len(raw_texts)
                    print(f'num of unique texts_{key_type}: ', len(raw_texts))

                    # get text embeddings by batch
                    text_embeddings = []
                    text_mask = []
                    batch_start_idx = 0
                    while batch_start_idx < num_texts:
                        batch_end_idx = min(batch_start_idx + 256, num_texts)
                        text_embeddings_temp = self.encode_text(raw_texts[batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs)
                        if text_sep:
                            text_embeddings.append(text_embeddings_temp[0])
                            text_mask.append(text_embeddings_temp[1])
                        else:
                            text_embeddings.append(text_embeddings_temp)
                        batch_start_idx = batch_end_idx
                    # torch.cat rejects an empty list; with no texts only the '' entry is stored
                    # and other texts are encoded lazily in __getitem__.
                    if text_embeddings:
                        text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()

                    self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                    if text_sep:
                        self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
                    else:
                        self.text_embedding_dict[key_type][''] = np.zeros(self.dim_embed_text).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
                    with open(emb_path, 'wb') as f:
                        pickle.dump(self.text_embedding_dict[key_type], f)
                    if text_sep:
                        if text_mask:
                            text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                        self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                        self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)  # for empty text have zero embedding, compatible with mdm text masking
                        with open(mask_path, 'wb') as f:
                            pickle.dump(self.text_mask_dict[key_type], f)

                # Keep the in-memory caches as CPU tensors (the pickles hold numpy arrays).
                for key in self.text_embedding_dict[key_type]:
                    self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device='cpu')
                    if text_sep:
                        self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device='cpu')

    def reset_max_len(self, length):
        assert length <= self.max_motion_length
        self.pointer = np.searchsorted(self.length_arr, length)
        print("Pointer Pointing at %d"%self.pointer)
        self.max_length = length

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load the CLIP model `clip_version` onto `device` and freeze it.

        The model is stored as `self.clip_model`, switched to eval mode, and
        every parameter has `requires_grad` disabled.
        """
        # Original author's note: jit must be False for training.
        model, _preprocess = clip.load(clip_version, device=device, jit=False)
        self.clip_model = model
        # Original author's note: redundant, CLIP is already float16 by default.
        clip.model.convert_weights(model)

        # Freeze CLIP weights
        model.eval()
        for param in model.parameters():
            param.requires_grad = False
    
    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs = 20, sep_mode=0):
        import pandas as pd
        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, 512]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding
                
        raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')
        if sep_mode == 0:
            split_df = raw_series.str.split(r'[,.]', n=max_segs - 1, expand=True)
        elif sep_mode == 1:
            split_df = raw_series.str.split(r'\band\b|\bwhile\b|,|\.', n=max_segs - 1, expand=True)
        split_df = split_df.fillna('').astype(str).applymap(str.strip)

        split_df = split_df.reindex(columns=range(max_segs), fill_value='')
        
        segs_matrix = split_df.values
        segs_flat = segs_matrix.reshape(-1).tolist()

        text_mask = (segs_matrix == '').astype(bool)
        text_mask = torch.tensor(text_mask, dtype=torch.bool, device=device)

        tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
        text_embedding = self.clip_model.encode_text(tokenized).float()      # [B*max_segs, 512]
        text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, 512]

        if force_empty_zero:
            text_embedding[text_mask] = 0

        return text_embedding, text_mask

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        new_text_embeddings = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs)
        for idx, text in enumerate(new_texts):
            if text_sep:
                self.text_embedding_dict[key_type][text] = new_text_embeddings[0][idx]
                self.text_mask_dict[key_type][text] = new_text_embeddings[1][idx]
            else:
                self.text_embedding_dict[key_type][text] = new_text_embeddings[idx]
    
    def get_mean_std_by_device(self, device):
        if device not in self.tensor_mean_device_dict:
            self.tensor_mean_device_dict[device] = (self.tensor_mean.to(device=device), self.tensor_std.to(device=device))
        return self.tensor_mean_device_dict[device]

    def normalize(self, tensor):
        tensor_mean, tensor_std = self.get_mean_std_by_device(tensor.device)
        tensor_std_safe = tensor_std.clone()
        tensor_std_safe[tensor_std == 0] = 1.0  # avoid division by zero
        return (tensor - tensor_mean) / tensor_std_safe  # [B, T, D]
    
    def denormalize(self, tensor):
        tensor_mean, tensor_std = self.get_mean_std_by_device(tensor.device)
        return tensor * tensor_std + tensor_mean  # [B, T, D]
    
    def __len__(self):
        return len(self.data_dict) - self.pointer
        
    def __getitem__(self, item):
        idx = self.pointer + item
        data = copy.deepcopy(self.data_dict[self.name_list[idx]])
        m_length, text_list = data['length'], data['text']
        text_data = random.choice(text_list)
        caption, tokens = text_data['caption'], text_data['tokens']
        if len(tokens) < self.max_text_len:
            # pad with "unk"
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
            tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len)
        else:
            # crop
            tokens = tokens[:self.max_text_len]
            tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
            sent_len = len(tokens)
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            try:
                word_emb, pos_oh = self.w_vectorizer[token]
            except:
                word_emb, pos_oh = self.w_vectorizer['unk/OTHER']
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)
        
        # Crop the motions in to times of 4, and introduce small variations
        if self.unit_length < 10:
            coin2 = np.random.choice(['single', 'single', 'double'])
        else:
            coin2 = 'single'
   
        if coin2 == 'double':
            m_length = (m_length // self.unit_length - 1) * self.unit_length
        elif coin2 == 'single':
            m_length = (m_length // self.unit_length) * self.unit_length
        cut_idx = random.randint(0, data['motion_p1']['transl'].shape[0] - m_length)
        
        for person in ['motion_p1', 'motion_p2']:
            for key in data[person].keys():
                if key in ['betas', 'transl', 'global_orient', 'body_pose', 'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose', 'joints']:
                    data[person][key] = data[person][key][cut_idx:cut_idx + m_length]
        
        data['person1'] = data.pop('motion_p1')
        data['person2'] = data.pop('motion_p2')
        for person in ['person1', 'person2']:
            data[person]['transf_rotmat'] = torch.eye(3).unsqueeze(0)
            data[person]['transf_transl'] = torch.zeros(1, 1, 3)
        for person in ['person1', 'person2']:
            for key in ['betas', 'transl', 'global_orient', 'body_pose', 'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose', 'pelvis_delta', 'joints', 'transf_rotmat', 'transf_transl']:
                if key in data[person]:
                    data[person][key] = data[person][key].squeeze(0)
        
        # padding
        if m_length < self.max_motion_length + 1:
            padding_length = self.max_motion_length + 1 - m_length
            for person in ['person1', 'person2']:
                for key in data[person].keys():
                    if key in ['betas', 'transl', 'global_orient', 'body_pose', 'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose', 'joints']:
                        data[person][key] = torch.cat([data[person][key], torch.zeros(padding_length, *data[person][key].shape[1:])], dim=0)
        else:
            for person in ['person1', 'person2']:
                for key in data[person].keys():
                    if key in ['betas', 'transl', 'global_orient', 'body_pose', 'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose', 'joints']:
                        data[person][key] = data[person][key][:self.max_motion_length + 1]
            m_length = self.max_motion_length + 1
        data.pop('text')
        if self.use_indi_text:
            for person in ['person1', 'person2']:
                primitive_texts = random.choice([
                    itext['proc_label'] for itext in data[f'frame_labels_{person}']
                ])
                if self.load_text_embedding:
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                    else:
                        text_mask = None
                    data[person]['text_embedding'] = text_embedding.detach().cpu()
                    data[person]['text_mask'] = text_mask.detach().cpu() if self.text_sep else None
                    data[person]['texts'] = primitive_texts
        else:
            if self.load_text_embedding:
                unseen_texts = [caption] if caption not in self.text_embedding_dict['interaction'] else []
                if len(unseen_texts) > 0:
                    try:
                        self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                    except:
                        print('error text: ', caption)
                        print(data['seq_name'])
                text_embedding = self.text_embedding_dict['interaction'][caption].unsqueeze(0)
                if self.text_sep:
                    text_mask = self.text_mask_dict['interaction'][caption].unsqueeze(0)
                else:
                    text_mask = None
                data['interaction'] = {
                    'texts': caption,
                    'text_embedding': text_embedding.detach().cpu(),
                }
                if self.text_sep:
                    data['interaction']['text_mask'] = text_mask.detach().cpu()
        return word_embeddings, pos_one_hots, caption, sent_len, data, m_length, '_'.join(tokens)

# dataset = InterXDatasetEval(enforce_gender=None,
#                             enforce_zero_beta=0,
#                             device='cuda:1',
#                             split='test',
#                             padding=True,
#                             w_vectorizer=WordVectorizer(pjoin('./data/Inter-X/inter-x', 'glove'), 'hhi_vab'))
# data = dataset[0]