import sys
sys.path.insert(0, sys.path[0]+r"/../../../")

from data_loaders.humanml.data.dataset import *
from utils.misc_util import load_and_freeze_t5_encoder, encode_text_t5
import copy
import random
import clip
from utils.intergen_util import process_motion, process_motion_refined, cal_rel_rot


class InterHumanMotion(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_zero_male',
                 cfg_path='./config_files/config_hydra/motion_primitive/hml_mp_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='sep', # 'sep' or 'merged'
                 **kwargs):
        """Two-person (InterHuman) motion-primitive dataset.

        Loads paired motion sequences from ``{dataset_path}/{split}.pkl``, converts
        them to torch tensors, assigns per-sequence sampling weights, and loads (or
        computes and caches) feature normalization statistics and CLIP text
        embeddings.

        Args:
            dataset_name: Name tag stored on the instance.
            dataset_path: Directory holding ``{split}.pkl`` and the cached
                mean/std and text-embedding pickles.
            cfg_path: OmegaConf YAML providing ``fps``, ``history_length``,
                ``future_length`` and ``num_primitive``.
            split: Split name used to pick the pickle files.
            device: Torch device for the primitive utility, CLIP model and text
                embeddings.
            weight_scheme: Sequence sampling scheme. It must contain ``'uniform'``
                (every sequence weighted 1.0) or ``'length'`` (weighted by frame
                count).
            prob_static: Probability, used by ``get_batch``, of sampling a static
                (single repeated frame) primitive.
            enforce_gender: Every loaded sequence's gender must equal this value
                (asserted per sequence).
            enforce_zero_beta: Must be truthy (asserted). Betas are then replaced by
                zeros.
            load_data: If False, ``{split}.pkl`` is not loaded. Mean/std and text
                embeddings must then already be cached on disk, because computing
                them reads ``self.dataset``.
            text_tolerance: Slack, in seconds, when matching text labels to a
                primitive's future window.
            body_type: Body model type forwarded to ``PrimitiveUtility``.
            seed_only: Sample only a single seed primitive; requires
                ``num_primitive == 1``.
            use_frame_weights: Not used by this constructor.
            mode: ``'sep'`` or ``'merged'`` batch layout, consumed by ``get_batch``.
            **kwargs: Ignored.

        Raises:
            ValueError: If ``weight_scheme`` contains neither ``'uniform'`` nor
                ``'length'`` and at least one sequence is loaded.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        # Presumably feature name -> per-frame feature size consumed by PrimitiveUtility.
        motion_repr = {'transl': 3,
                       'poses_6d': 22 * 6,
                       'transl_delta': 3,
                       'global_orient_delta_6d': 6,
                       'joints': 22 * 3,
                       'joints_delta': 22 * 3,
                       }
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # History + num_primitive future chunks, plus one extra frame for delta features.
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1

        if load_data:
            # NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            # Keep only sequences long enough for one full multi-primitive rollout.
            dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            def _to_tensor_motion(motion):
                """Convert one person's raw numpy motion dict into the tensor dict used downstream."""
                gender = self.enforce_gender if self.enforce_gender is not None else motion['gender']
                betas = torch.from_numpy(motion['betas'].astype(np.float32))
                if self.enforce_zero_beta:
                    betas = torch.zeros_like(betas)
                poses = torch.from_numpy(motion['poses'].astype(np.float32))
                return {
                    'gender': gender,
                    'betas': betas,
                    'transl': torch.from_numpy(motion['trans'].astype(np.float32)),
                    'global_orient': transforms.axis_angle_to_matrix(poses[:, :3]),  # [T, 3, 3]
                    'body_pose': transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3)),  # [T, 21, 3, 3]
                    'pelvis_delta': torch.from_numpy(motion['pelvis_delta'].astype(np.float32)),  # [3]
                    'joints': torch.from_numpy(motion['joints'].astype(np.float32)),  # [T, 22, 3]
                }

            for data in dataset:
                # Every sequence's gender must match enforce_gender, and
                # enforce_zero_beta must be set.
                assert self.enforce_gender == data['motion_p1']['gender']
                assert self.enforce_gender == data['motion_p2']['gender']
                assert self.enforce_zero_beta
                data['motion_p1'], data['motion_p2'] = (_to_tensor_motion(data['motion_p1']),
                                                        _to_tensor_motion(data['motion_p2']))
            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    # motion_p1 was rebuilt above, so the frame tensor now lives under 'transl'.
                    data['weight'] = len(data['motion_p1']['transl'])
                else:
                    raise ValueError(f'unsupported weight_scheme: {weight_scheme!r}')
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}_allcanonicalized'
        # TODO: use different mean and std when enforce gender and beta
        mean_std_path = Path(dataset_path, f'{file_name}.pkl')
        if mean_std_path.exists():
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        else:
            # Statistics are only ever computed from the train split.
            assert self.split == 'train'
            print('calculating mean and std using train split')
            self.tensor_mean, self.tensor_std = self.calc_mean_std()
            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        self.clip_model = load_and_freeze_clip(clip_version='ViT-B/32', device=self.device)
        self.embedding_path = embedding_path = Path(dataset_path, f'{split}_text_embedding_dict.pkl')
        if embedding_path.exists():
            print(f'loading text embeddings from {embedding_path}')
            with open(embedding_path, 'rb') as f:
                self.text_embedding_dict = pickle.load(f)
        else:
            print('calculating text embeddings')
            raw_texts = []
            for data in self.dataset:
                if 'frame_labels' in data:
                    raw_texts.extend([seg['proc_label'] for seg in data['frame_labels']])
            raw_texts = list(set(raw_texts))
            num_texts = len(raw_texts)
            print('num of unique texts: ', len(raw_texts))
            # Encode in batches of 256 texts to bound memory use.
            text_embeddings = []
            batch_start_idx = 0
            while batch_start_idx < num_texts:
                batch_end_idx = min(batch_start_idx + 256, num_texts)
                text_embeddings.append(encode_text(self.clip_model, raw_texts[batch_start_idx:batch_end_idx]))
                batch_start_idx = batch_end_idx
            text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()
            print(text_embeddings.shape)
            self.text_embedding_dict = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
            self.text_embedding_dict[''] = np.zeros(512).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
            with open(embedding_path, 'wb') as f:
                pickle.dump(self.text_embedding_dict, f)
        for key in self.text_embedding_dict:
            self.text_embedding_dict[key] = torch.from_numpy(self.text_embedding_dict[key]).to(dtype=torch.float32, device=self.device)

    def calc_mean_std(self, batch_size=512):
        """Compute per-feature mean and std of canonicalized primitives over ``self.dataset``.

        Each sequence is cut into overlapping primitives with stride
        ``future_length``. Each person's primitives are canonicalized
        independently and converted to features with ``calc_features``, then
        flattened with ``dict_to_tensor``. Both persons' features are pooled
        before the statistics are computed.

        Args:
            batch_size: Number of primitives canonicalized per chunk. It is capped
                at 64 when ``future_length == 1``.

        Returns:
            Tuple ``(tensor_mean, tensor_std)``, each shaped ``[1, 1, D]`` and on
            ``self.device``.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        persons = ['person1', 'person2']
        tensor_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']
        all_mp_data = []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['transl'].shape[0]
            # get_primitive slices frames start..end inclusively, so the last usable
            # start frame is num_frames - primitive_length - 1 (exclusive bound below).
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]

            # Stack all primitives of this sequence along dim 0, separately per person.
            primitive_dict = {}
            for person in persons:
                person_dict = {
                    # A single gender string shared by the whole sequence; it must not be
                    # wrapped in a set (get_batch also passes a plain string).
                    'gender': primitive_data_list[0]['primitive_dict'][person]['gender'],
                }
                for key in tensor_keys:
                    person_dict[key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(person_dict, self.device)

            # Process the stacked primitives in chunks of at most batch_size.
            num_mp = len(primitive_dict['person1']['transl'])
            for batch_start_idx in range(0, num_mp, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_mp)
                canonicalized_primitive_dict = {}
                for person in persons:
                    batch_primitive_dict = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict['gender'] = primitive_dict[person]['gender']
                    # Each person is canonicalized in its own frame; the returned
                    # transforms are not needed for the statistics.
                    _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(batch_primitive_dict, use_predicted_joints=True)

                for person in persons:
                    feature_dict = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    # Drop the extra trailing frame that get_primitive includes for delta features.
                    feature_dict['transl'] = feature_dict['transl'][:, :-1, :]          # [num_primitive, T, 3]
                    feature_dict['poses_6d'] = feature_dict['poses_6d'][:, :-1, :]      # [num_primitive, T, 22 * 6]
                    feature_dict['joints'] = feature_dict['joints'][:, :-1, :]          # [num_primitive, T, 22 * 3]
                    all_mp_data.append(self.dict_to_tensor(feature_dict).detach().cpu())  # [num_primitive, T, D]

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        return tensor_mean.to(self.device), tensor_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """Slice one motion primitive for both persons of a sequence.

        Frames ``start_frame`` through ``end_frame`` are taken *inclusively*. That
        is one frame more than the primitive itself, so delta features can be
        computed.

        Args:
            seq_data: Sequence record with ``'motion_p1'``/``'motion_p2'`` tensor
                dicts and optional ``'frame_labels'``.
            start_frame: First frame index of the primitive.
            end_frame: Last frame index, inclusive.
            skip_text: If True, do not look up text labels.

        Returns:
            ``{'text': str, 'primitive_dict': {'person1': {...}, 'person2': {...}}}``.
            ``text`` is one randomly chosen label whose time span overlaps the
            future window (widened by ``text_tolerance`` seconds), or ``''``.
        """
        window = slice(start_frame, end_frame + 1)  # include one more frame for delta feature calculation
        padded_length = self.primitive_length + 1

        def build_person(motion):
            # Add a leading batch dimension of size 1 to every per-person tensor.
            return {
                'gender': motion['gender'],
                'betas': motion['betas'].expand(1, padded_length, 10),
                'transl': motion['transl'][window].unsqueeze(0),
                'global_orient': motion['global_orient'][window].unsqueeze(0),
                'body_pose': motion['body_pose'][window].unsqueeze(0),
                'pelvis_delta': motion['pelvis_delta'].unsqueeze(0),
                'joints': motion['joints'][window].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }

        primitive_dict = {
            'person1': build_person(seq_data['motion_p1']),
            'person2': build_person(seq_data['motion_p2']),
        }

        candidate_texts = []
        if not skip_text and 'frame_labels' in seq_data:
            # Future window of this primitive in seconds, widened by the tolerance.
            window_start = (start_frame + self.history_length) / self.target_fps - self.text_tolerance
            window_end = (start_frame + self.history_length + self.future_length - 1) / self.target_fps + self.text_tolerance
            candidate_texts = [
                seg['proc_label']
                for seg in seq_data['frame_labels']
                if have_overlap([seg['start_t'], seg['end_t']], [window_start, window_end])
            ]

        chosen_text = random.choice(candidate_texts) if candidate_texts else ''
        return {
            'text': chosen_text,
            'primitive_dict': primitive_dict,
        }

    def get_batch(self, batch_size=8):
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        add_key_list = ['texts', 'gender'] if self.mode=='sep' else ['texts', 'gender_p1', 'gender_p2']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'text_embedding'] if self.mode=='sep' else ['betas_p1', 'betas_p2', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'text_embedding']
        
        for seq_idx in batch_idx:
            seq_data = self.dataset[seq_idx]
            num_frames = len(seq_data['motion_p1']['transl'])
            if self.prob_static > 0 and random.random() < self.prob_static:
                static_frame = random.randint(0, num_frames - 1) # right end inclusive
                motion_data_p1 = seq_data['motion_p1']
                motion_data_p2 = seq_data['motion_p2']
                primitive_length = self.primitive_length
                primitive_dict = {}
                primitive_dict['person1'] = {
                    'gender': motion_data_p1['gender'],
                    'betas': motion_data_p1['betas'].expand(1, primitive_length + 1, 10),
                    'transl': motion_data_p1['transl'][[static_frame]].expand(primitive_length + 1, -1).unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient':
                        motion_data_p1['global_orient'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'body_pose':
                        motion_data_p1['body_pose'][[static_frame]].repeat(primitive_length + 1, 1, 1, 1).unsqueeze(0),
                    'pelvis_delta': motion_data_p1['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p1['joints'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_dict['person2'] = {
                    'gender': motion_data_p2['gender'],
                    'betas': motion_data_p2['betas'].expand(1, primitive_length + 1, 10),
                    'transl': motion_data_p2['transl'][[static_frame]].expand(primitive_length + 1, -1).unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient':
                        motion_data_p2['global_orient'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'body_pose':
                        motion_data_p2['body_pose'][[static_frame]].repeat(primitive_length + 1, 1, 1, 1).unsqueeze(0),
                    'pelvis_delta': motion_data_p2['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p2['joints'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_data = {
                    'text': '',
                    'primitive_dict': primitive_dict
                }
                primitive_data_list = [primitive_data] * self.num_primitive
                # print('get static sequenece')
            elif self.seed_only:  # only take the first primitive for predicting initial seed
                # print('get seed')
                frame_labels = []
                for seg in seq_data['frame_labels']:
                    start_frame = int(seg['start_t'] * self.target_fps)
                    end_frame = start_frame + self.primitive_length
                    if end_frame < num_frames:
                        frame_labels.append((start_frame, end_frame, seg['proc_label']))
                start_frame, end_frame, text = random.choice(frame_labels)

                motion_data_p1 = seq_data['motion_p1']
                motion_data_p2 = seq_data['motion_p2']
                primitive_dict = {}
                primitive_dict['person1'] = {
                    'gender': motion_data_p1['gender'],
                    'betas': motion_data_p1['betas'].expand(1, self.primitive_length + 1, 10),
                    'transl': motion_data_p1['transl'][start_frame:end_frame + 1].unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient': motion_data_p1['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                    'body_pose': motion_data_p1['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                    'pelvis_delta': motion_data_p1['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p1['joints'][start_frame:end_frame + 1].unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_dict['person2'] = {
                    'gender': motion_data_p2['gender'],
                    'betas': motion_data_p2['betas'].expand(1, self.primitive_length + 1, 10),
                    'transl': motion_data_p2['transl'][start_frame:end_frame + 1].unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient': motion_data_p2['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                    'body_pose': motion_data_p2['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                    'pelvis_delta': motion_data_p2['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p2['joints'][start_frame:end_frame + 1].unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }

                primitive_data_list = [
                    {
                        'text': text,
                        'primitive_dict': primitive_dict,
                    }
                ]
            else:
                if 'text' in self.weight_scheme:
                    start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
                else:
                    start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
                primitive_data_list = []
                for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                    primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                    primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        batch = None
        for gender in ['female', 'male']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = []
            
            gender_seq_texts = None
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_texts = [mp_seq[primitive_idx]['text'] for mp_seq in gender_seq_list]
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                gender_seq_texts = primitive_texts if gender_seq_texts is None else gender_seq_texts + primitive_texts
                
                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)

            canonicalized_primitive_dict = {}
            transf_rotmat = {}
            transf_transl = {}
            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)
            if random.random() < 0.5:
                p_flag = 0
                rel_rotmat, rel_transl = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            else:
                p_flag = 1
                rel_rotmat, rel_transl = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])

            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]

            relative_transfer = {}
            if p_flag == 0:
                relative_transfer['person1']['rel_rotmat_6d'] = transforms.matrix_to_rotation_6d(torch.eye(3, device=rel_rotmat.device).unsqueeze(0).expand(batch_size, -1, -1))
                relative_transfer['person1']['rel_transl'] = torch.zeros_like(rel_transl)
                relative_transfer['person2']['rel_rotmat_6d'] = transforms.matrix_to_rotation_6d(rel_rotmat)
                relative_transfer['person2']['rel_transl'] = rel_transl
            elif p_flag == 1:
                relative_transfer['person2']['rel_rotmat_6d'] = transforms.matrix_to_rotation_6d(torch.eye(3, device=rel_rotmat.device).unsqueeze(0).expand(batch_size, -1, -1))
                relative_transfer['person2']['rel_transl'] = torch.zeros_like(rel_transl)
                relative_transfer['person1']['rel_rotmat_6d'] = transforms.matrix_to_rotation_6d(rel_rotmat)
                relative_transfer['person1']['rel_transl'] = rel_transl
            
            if self.mode == 'sep':
                for person in ['person1', 'person2']:
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        primitive_texts = gender_seq_texts[start_idx:end_idx]
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts)
                        text_embedding = torch.stack([self.text_embedding_dict[text] for text in primitive_texts], dim=0)  # [B, 512]
                        gender_batch.append(
                            {
                                'texts': primitive_texts,
                                'text_embedding': text_embedding,
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            }
                        )
                
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]
                
                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key], back_group[key] = [], []
                    for d in gender_batch[:self.num_primitive]:
                        front_group[key] += d[key]
                    for d in gender_batch[self.num_primitive:]:
                        back_group[key] += d[key]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                    back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]  
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]  

                selected_batch = []
                for i in range(self.num_primitive):
                    selected_dict = {'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,}
                    for key in front_group.keys():    
                        if key in add_key_list:
                            selected_front = [front_group[key][i] for i in front_indices + i * gender_batch_size] 
                            selected_back = [back_group[key][i] for i in back_indices + i * gender_batch_size]
                            selected_dict[key] = selected_front + selected_back
                        elif key in cat_key_list:
                            selected_front = front_group[key][front_indices + i * gender_batch_size] 
                            selected_back = back_group[key][back_indices + i * gender_batch_size]
                            selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)  
                    selected_batch.append(selected_dict)
                gender_batch = selected_batch
                
            elif self.mode == 'merged':
                # motion_tensor_normalized = self.normalize(torch.cat((self.dict_to_tensor(feature_dict['person1']), self.dict_to_tensor(feature_dict['person2'])), dim=-1))  # [B*num_mp, T, 2*D]
                motion_tensor_normalized = torch.cat((self.normalize(self.dict_to_tensor(feature_dict['person1'])), 
                                                     self.normalize(self.dict_to_tensor(feature_dict['person2']))), dim=-1)
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)  # [B*num_mp, 2*D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts[start_idx:end_idx]
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts)
                    text_embedding = torch.stack([self.text_embedding_dict[text] for text in primitive_texts], dim=0)  # [B, 512]
                    gender_batch.append(
                        {
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'gender_p1': [gender_seq_dict['person1']['gender']] * gender_batch_size,
                            'betas_p1': gender_seq_dict['person1']['betas'][start_idx:end_idx, :-1, :10],
                            'gender_p2': [gender_seq_dict['person2']['gender']] * gender_batch_size,
                            'betas_p2': gender_seq_dict['person2']['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, 2*D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        }
                    )

            
            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    for key in add_key_list:
                        batch[primitive_idx][key] = batch[primitive_idx][key] + gender_batch[primitive_idx][key]
                    for key in cat_key_list:
                        batch[primitive_idx][key] = torch.cat([batch[primitive_idx][key], gender_batch[primitive_idx][key]], dim=0)

        return batch


class InterHumanMotionV2(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman_single',
                 dataset_path='./data/InterHuman/seq_data_single_zero_male_fps20',
                 cfg_path='./config_files/config_hydra/motion_primitive/hml_mp_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='sep', # 'sep' or 'merged'
                 text_sep = False,
                 max_segs = 20,
                 **kwargs):
        """Build the two-person motion-primitive dataset with per-person text labels.

        Args:
            dataset_name: identifier stored on the instance.
            dataset_path: directory holding ``{split}.pkl`` plus the cached
                mean/std and per-person text-embedding pickles.
            cfg_path: YAML config providing ``fps``, ``history_length``,
                ``future_length`` and ``num_primitive``.
            split: split name used to pick ``{split}.pkl`` and cache file names.
            device: torch device for the cached text embeddings.
            weight_scheme: sequence sampling weights; containing 'uniform' gives
                every sequence weight 1.0, containing 'length' weights by frame
                count. Any other value leaves weights unset (KeyError below).
            prob_static: stored on the instance; not used by this constructor.
            enforce_gender: must equal the stored gender of both persons.
            enforce_zero_beta: must be truthy; all betas are replaced by zeros.
            load_data: when False, ``{split}.pkl`` is not read and ``self.dataset``
                is not set, so the mean/std and text-embedding caches must
                already exist.
            text_tolerance: slack added on both sides of a primitive's future
                window when matching text labels (units of the labels' times).
            body_type: forwarded to ``PrimitiveUtility``.
            seed_only: when True, ``num_primitive`` must be 1.
            use_frame_weights: accepted for interface compatibility; unused here.
            mode: 'sep' or 'merged'; stored on the instance.
            text_sep: when True, ``encode_text`` also yields a per-text mask, and
                embeddings/masks are cached under ``*_textsep`` file names.
            max_segs: forwarded to ``encode_text``; also the length of the
                all-False mask stored for the empty text.
            **kwargs: ignored.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        motion_repr = {'transl': 3,
                       'poses_6d': 22 * 6,
                       'transl_delta': 3,
                       'global_orient_delta_6d': 6,
                       'joints': 22 * 3,
                       'joints_delta': 22 * 3,
                       }
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # one extra frame so the last primitive can still compute delta features
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1

        if load_data:
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            # keep only sequences long enough for one full multi-primitive sample
            dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            # sequences explicitly excluded by name
            elements_to_remove = ['7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385']
            dataset = [data for data in dataset if data['seq_name'] not in elements_to_remove]

            for data in dataset:
                # This variant requires a fixed gender matching the data and zero betas.
                assert self.enforce_gender == data['motion_p1']['gender']
                assert self.enforce_gender == data['motion_p2']['gender']
                assert self.enforce_zero_beta
                gender_p1 = self.enforce_gender if self.enforce_gender is not None else data['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else data['motion_p2']['gender']
                betas_p1 = torch.from_numpy(data['motion_p1']['betas'].astype(np.float32))
                betas_p2 = torch.from_numpy(data['motion_p2']['betas'].astype(np.float32))
                if self.enforce_zero_beta:
                    betas_p1 = torch.zeros_like(betas_p1)
                    betas_p2 = torch.zeros_like(betas_p2)

                transl_p1 = torch.from_numpy(data['motion_p1']['trans'].astype(np.float32))
                poses_p1 = torch.from_numpy(data['motion_p1']['poses'].astype(np.float32))
                transl_p2 = torch.from_numpy(data['motion_p2']['trans'].astype(np.float32))
                poses_p2 = torch.from_numpy(data['motion_p2']['poses'].astype(np.float32))

                global_orient_p1 = transforms.axis_angle_to_matrix(poses_p1[:, :3])  # [T, 3, 3]
                body_pose_p1 = transforms.axis_angle_to_matrix(poses_p1[:, 3:66].reshape(-1, 21, 3))  # [T, 21, 3, 3]
                pelvis_delta_p1 = torch.from_numpy(data['motion_p1']['pelvis_delta'].astype(np.float32))  # [3]
                joints_p1 = torch.from_numpy(data['motion_p1']['joints'].astype(np.float32))  # [T, 22, 3]
                global_orient_p2 = transforms.axis_angle_to_matrix(poses_p2[:, :3])  # [T, 3, 3]
                body_pose_p2 = transforms.axis_angle_to_matrix(poses_p2[:, 3:66].reshape(-1, 21, 3))  # [T, 21, 3, 3]
                pelvis_delta_p2 = torch.from_numpy(data['motion_p2']['pelvis_delta'].astype(np.float32))  # [3]
                joints_p2 = torch.from_numpy(data['motion_p2']['joints'].astype(np.float32))  # [T, 22, 3]

                # Replace the raw per-person records; note the translation key becomes 'transl'.
                data['motion_p1'] = {
                    'gender': gender_p1,
                    'betas': betas_p1,
                    'transl': transl_p1,
                    'global_orient': global_orient_p1,
                    'body_pose': body_pose_p1,
                    'pelvis_delta': pelvis_delta_p1,
                    'joints': joints_p1,
                }
                data['motion_p2'] = {
                    'gender': gender_p2,
                    'betas': betas_p2,
                    'transl': transl_p2,
                    'global_orient': global_orient_p2,
                    'body_pose': body_pose_p2,
                    'pelvis_delta': pelvis_delta_p2,
                    'joints': joints_p2,
                }
            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    # 'motion_p1' was rebuilt above, so the frame data lives under 'transl'
                    data['weight'] = len(data['motion_p1']['transl'])
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'

        mean_std_path = Path(dataset_path, f'{file_name}.pkl')
        if mean_std_path.exists():
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        else:
            # statistics are only ever derived from the training split
            assert self.split == 'train'
            print('calculating mean and std using train split')
            self.tensor_mean, self.tensor_std = self.calc_mean_std()
            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        self.clip_model = load_and_freeze_clip(clip_version='ViT-B/32', device=self.device)
        self.embedding_path = {}
        embedding_path = {}
        for person in ['person1', 'person2']:
            if text_sep:
                self.embedding_path[person] = embedding_path[person] = Path(dataset_path, f'{split}_{person}_text_embedding_dict_single_textsep.pkl')
            else:
                self.embedding_path[person] = embedding_path[person] = Path(dataset_path, f'{split}_{person}_text_embedding_dict_single.pkl')
        self.text_embedding_dict = {}
        if text_sep:
            self.text_mask_dict = {}

        if embedding_path['person1'].exists() and embedding_path['person2'].exists():
            print(f"loading text1 embeddings from {embedding_path['person1']}, loading text2 embeddings from {embedding_path['person2']}")
            for person in ['person1', 'person2']:
                with open(embedding_path[person], 'rb') as f:
                    self.text_embedding_dict[person] = pickle.load(f)
                if text_sep:
                    with open(Path(str(embedding_path[person]).replace('text_embedding_dict', 'text_mask_dict')), 'rb') as f:
                        self.text_mask_dict[person] = pickle.load(f)
        else:
            print('calculating text embeddings')
            raw_texts = {'person1':[], 'person2': []}
            for data in self.dataset:
                if 'frame_labels_p1' in data and 'frame_labels_p2' in data:
                    raw_texts['person1'].extend([seg['proc_label'] for seg in data['frame_labels_p1']])
                    raw_texts['person2'].extend([seg['proc_label'] for seg in data['frame_labels_p2']])
            raw_texts['person1'] = list(set(raw_texts['person1']))
            raw_texts['person2'] = list(set(raw_texts['person2']))
            num_texts = {
                'person1': len(raw_texts['person1']),
                'person2': len(raw_texts['person2']),
            }
            print('num of unique texts_p1: ', len(raw_texts['person1']))
            print('num of unique texts_p2: ', len(raw_texts['person2']))
            # get text embeddings by batch
            text_embeddings = {'person1':[], 'person2': []}
            text_mask = {'person1':[], 'person2': []}
            for person in ['person1', 'person2']:
                # encode unique texts in chunks of at most 256
                for batch_start_idx in range(0, num_texts[person], 256):
                    batch_end_idx = min(batch_start_idx + 256, num_texts[person])
                    text_embeddings_temp = encode_text(self.clip_model, raw_texts[person][batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs)
                    if text_sep:
                        # with text_sep, encode_text returns (embeddings, masks) for the chunk
                        text_embeddings[person].append(text_embeddings_temp[0])
                        text_mask[person].append(text_embeddings_temp[1])
                    else:
                        text_embeddings[person].append(text_embeddings_temp)
                # torch.cat rejects an empty list, so only concatenate when texts exist
                if num_texts[person] > 0:
                    text_embeddings[person] = torch.cat(text_embeddings[person], dim=0).detach().cpu().numpy()

                self.text_embedding_dict[person] = {raw_texts[person][idx]: text_embeddings[person][idx] for idx in range(num_texts[person])}
                # NOTE(review): with text_sep the per-text embedding may not be 512-d; confirm
                # this zero vector stacks with the others.
                self.text_embedding_dict[person][''] = np.zeros(512).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
                with open(embedding_path[person], 'wb') as f:
                    pickle.dump(self.text_embedding_dict[person], f)
                if text_sep:
                    if num_texts[person] > 0:
                        text_mask[person] = torch.cat(text_mask[person], dim=0).detach().cpu().numpy()
                    self.text_mask_dict[person] = {raw_texts[person][idx]: text_mask[person][idx] for idx in range(num_texts[person])}
                    self.text_mask_dict[person][''] = np.zeros(max_segs).astype(np.bool_)  # empty text gets an all-False mask
                    with open(Path(str(embedding_path[person]).replace('text_embedding_dict', 'text_mask_dict')), 'wb') as f:
                        pickle.dump(self.text_mask_dict[person], f)
        # move every cached numpy entry to a torch tensor on the target device
        for person in ['person1', 'person2']:
            for key in self.text_embedding_dict[person]:
                self.text_embedding_dict[person][key] = torch.from_numpy(self.text_embedding_dict[person][key]).to(dtype=torch.float32, device=self.device)
                if text_sep:
                    self.text_mask_dict[person][key] = torch.from_numpy(self.text_mask_dict[person][key]).to(dtype=torch.bool, device=self.device)

    def update_text_embedding_dict(self, new_texts, person, text_sep=False, max_segs=20):
        """Encode texts that are not cached yet and add them to ``person``'s tables.

        Args:
            new_texts: list of strings to encode with ``self.clip_model``.
            person: 'person1' or 'person2'; selects the per-person dictionaries.
            text_sep: when True, ``encode_text`` returns an ``(embeddings, masks)``
                pair for the whole batch (the same contract ``__init__`` relies
                on), and each text's mask is cached in ``self.text_mask_dict``.
            max_segs: forwarded to ``encode_text``.
        """
        encoded = encode_text(self.clip_model, new_texts, text_sep=text_sep, max_segs=max_segs)
        if text_sep:
            # The result is a batch-level tuple; index the text inside each tensor,
            # not the tuple itself.
            new_text_embeddings, new_text_masks = encoded
        else:
            new_text_embeddings, new_text_masks = encoded, None
        for idx, text in enumerate(new_texts):
            # Match the dtype/device of the entries cached by __init__ so that
            # torch.stack over mixed cached/new texts stays consistent.
            self.text_embedding_dict[person][text] = new_text_embeddings[idx].to(dtype=torch.float32, device=self.device)
            if text_sep:
                self.text_mask_dict[person][text] = new_text_masks[idx].to(dtype=torch.bool, device=self.device)

    def calc_mean_std(self, batch_size=512):
        """Compute per-dimension feature mean/std over all primitives of ``self.dataset``.

        Each sequence is cut into primitives starting every ``self.future_length``
        frames. Both persons' primitives are canonicalized and converted to
        features, and the features of both persons are pooled before reducing
        over samples and frames.

        Args:
            batch_size: number of primitives per canonicalize/calc_features call;
                capped at 64 when ``self.future_length == 1``.

        Returns:
            ``(tensor_mean, tensor_std)``, each shaped [1, 1, D], on ``self.device``.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        persons = ['person1', 'person2']
        tensor_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']
        all_mp_data = []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['transl'].shape[0]
            # get_primitive reads frames start..start + primitive_length inclusive,
            # so this exclusive stop keeps every slice inside the sequence.
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]

            primitive_dict = {}
            for person in persons:
                # Gender is kept as a plain string, matching get_batch
                # (previously it was wrapped in a one-element set by mistake).
                primitive_dict[person] = {'gender': primitive_data_list[0]['primitive_dict'][person]['gender']}
                for key in tensor_keys:
                    primitive_dict[person][key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(primitive_dict[person], self.device)

            # process the primitives of this sequence in chunks of batch_size
            num_primitives = len(primitive_dict['person1']['transl'])
            for batch_start_idx in range(0, num_primitives, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)
                canonicalized_primitive_dict = {}
                for person in persons:
                    batch_primitive_dict = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict['gender'] = primitive_dict[person]['gender']
                    _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(batch_primitive_dict, use_predicted_joints=True)

                for person in persons:
                    feature_dict = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    # drop the trailing frame that was only kept for delta features
                    feature_dict['transl'] = feature_dict['transl'][:, :-1, :]          # [n, T, 3]
                    feature_dict['poses_6d'] = feature_dict['poses_6d'][:, :-1, :]      # [n, T, 22 * 6]
                    feature_dict['joints'] = feature_dict['joints'][:, :-1, :]          # [n, T, 22 * 3]
                    all_mp_data.append(self.dict_to_tensor(feature_dict).detach().cpu())  # [n, T, D]

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        return tensor_mean.to(self.device), tensor_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """Slice frames ``start_frame``..``end_frame`` (both inclusive) for both persons.

        Args:
            seq_data: sequence record with 'motion_p1'/'motion_p2' dicts and,
                optionally, 'frame_labels_p1'/'frame_labels_p2' segment lists.
            start_frame: first frame of the slice.
            end_frame: last frame of the slice (included, so delta features can
                be computed for the final primitive frame).
            skip_text: when True, no text labels are looked up.

        Returns:
            dict with 'text_p1' / 'text_p2' (a random label overlapping the future
            window, or '' when none) and 'primitive_dict' keyed by 'person1' /
            'person2', each entry carrying a leading batch axis of size 1.
        """
        frames = slice(start_frame, end_frame + 1)
        expanded_len = self.primitive_length + 1

        def slice_person(motion):
            # Per-person tensors with a batch axis; canonicalization starts from identity.
            return {
                'gender': motion['gender'],
                'betas': motion['betas'].expand(1, expanded_len, 10),
                'transl': motion['transl'][frames].unsqueeze(0),
                'global_orient': motion['global_orient'][frames].unsqueeze(0),
                'body_pose': motion['body_pose'][frames].unsqueeze(0),
                'pelvis_delta': motion['pelvis_delta'].unsqueeze(0),
                'joints': motion['joints'][frames].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }

        primitive_dict = {
            'person1': slice_person(seq_data['motion_p1']),
            'person2': slice_person(seq_data['motion_p2']),
        }

        candidates = {'text_p1': [], 'text_p2': []}
        labelled = 'frame_labels_p1' in seq_data and 'frame_labels_p2' in seq_data
        if labelled and not skip_text:
            first_future = start_frame + self.history_length
            window_lo = first_future / self.target_fps - self.text_tolerance
            window_hi = (first_future + self.future_length - 1) / self.target_fps + self.text_tolerance
            for text_key, label_key in (('text_p1', 'frame_labels_p1'), ('text_p2', 'frame_labels_p2')):
                candidates[text_key] = [
                    seg['proc_label']
                    for seg in seq_data[label_key]
                    if have_overlap([seg['start_t'], seg['end_t']], [window_lo, window_hi])
                ]

        return {
            'text_p1': random.choice(candidates['text_p1']) if candidates['text_p1'] else '',
            'text_p2': random.choice(candidates['text_p2']) if candidates['text_p2'] else '',
            'primitive_dict': primitive_dict,
        }

    def get_batch(self, batch_size=8):
        """Sample a training batch of ``self.num_primitive`` consecutive primitives.

        Each sampled sequence yields one sample per person (person1 and person2).
        Within each gender group of B sequences, B of the resulting 2*B
        per-person samples are kept, chosen via a random permutation. Groups are
        concatenated in the order female, then male.

        Args:
            batch_size: number of sequences requested from ``self.get_batch_idx``.

        Returns:
            list of length ``self.num_primitive`` (or None if no sequence was
            sampled). Entry ``i`` is a dict with list-valued 'texts' and 'gender';
            tensor-valued 'betas', 'motion_tensor_normalized', 'history_motion',
            'history_mask', 'text_embedding' (plus 'text_mask' when
            ``self.text_sep``); and the ints 'history_length' / 'future_length'.
        """
        self.time = time.time()  # wall-clock time at which this batch was requested
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        # keys merged by list concatenation vs. by torch.cat along dim 0
        add_key_list = ['texts', 'gender']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'text_embedding']
        if self.text_sep:
            cat_key_list.append('text_mask')
        
        # 1. pick a random window per sequence and cut it into num_primitive primitives
        for seq_idx in batch_idx:
            seq_data = self.dataset[seq_idx]
            num_frames = len(seq_data['motion_p1']['transl'])
            if 'text' in self.weight_scheme:
                # NOTE(review): 'frame_weights' is not computed by this class's __init__; confirm it exists.
                start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
            else:
                start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
            primitive_data_list = []
            # consecutive primitives overlap by history_length frames (stride future_length)
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        batch = None
        for gender in ['female', 'male']:
            # grouping uses person1's gender only
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = []
            
            # 2. stack everything primitive-major: rows [p*B:(p+1)*B] belong to primitive p
            gender_seq_texts = {
                'person1': None,
                'person2': None,
            }
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_texts_p1 = [mp_seq[primitive_idx]['text_p1'] for mp_seq in gender_seq_list]
                primitive_texts_p2 = [mp_seq[primitive_idx]['text_p2'] for mp_seq in gender_seq_list]
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                gender_seq_texts['person1'] = primitive_texts_p1 if gender_seq_texts['person1'] is None else gender_seq_texts['person1'] + primitive_texts_p1
                gender_seq_texts['person2'] = primitive_texts_p2 if gender_seq_texts['person2'] is None else gender_seq_texts['person2'] + primitive_texts_p2
                
                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)

            # 3. canonicalize each person separately; the deepcopy presumably protects
            #    gender_seq_dict (its 'betas' are read below) from in-place edits
            canonicalized_primitive_dict = {}
            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
                _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)

            # 4. features -> normalized tensor, history conditioning, per-primitive entries
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                # drop the trailing frame that was only kept for delta features
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 22 * 6]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]
            
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                # the first history_length frames are exposed as conditioning
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                    # lazily encode texts that were not cached at construction time
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                    else:
                        text_mask = None
                    # gender_batch ends up as num_primitive person1 entries followed by num_primitive person2 entries
                    gender_batch.append(
                        {
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'text_mask': text_mask, 
                            'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        }
                    )
            
            # 5. keep B of the 2*B per-person samples: a shuffled 0/1 selector with
            #    exactly B ones; its first half picks person1 rows, second half person2 rows
            selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
            selector = selector[torch.randperm(2 * gender_batch_size)]
            
            # flatten each person's per-primitive entries (primitive-major order)
            front_group, back_group = {}, {}
            for key in add_key_list:
                front_group[key], back_group[key] = [], []
                for d in gender_batch[:self.num_primitive]:
                    front_group[key] += d[key]
                for d in gender_batch[self.num_primitive:]:
                    back_group[key] += d[key]
            for key in cat_key_list:
                front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

            front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]  
            back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]  

            # apply the same selection to every primitive so sequences stay aligned across primitives
            selected_batch = []
            for i in range(self.num_primitive):
                selected_dict = {'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,}
                for key in front_group.keys():    
                    if key in add_key_list:
                        # NOTE: the comprehension's `i` is its own variable; the iterable
                        # is evaluated with the outer primitive index `i`.
                        selected_front = [front_group[key][i] for i in front_indices + i * gender_batch_size] 
                        selected_back = [back_group[key][i] for i in back_indices + i * gender_batch_size]
                        selected_dict[key] = selected_front + selected_back
                    elif key in cat_key_list:
                        selected_front = front_group[key][front_indices + i * gender_batch_size] 
                        selected_back = back_group[key][back_indices + i * gender_batch_size]
                        selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)  
                selected_batch.append(selected_dict)
            gender_batch = selected_batch
                            
            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    for key in add_key_list:
                        batch[primitive_idx][key] = batch[primitive_idx][key] + gender_batch[primitive_idx][key]
                    for key in cat_key_list:
                        batch[primitive_idx][key] = torch.cat([batch[primitive_idx][key], gender_batch[primitive_idx][key]], dim=0)
        return batch
    
    def get_item(self, idx):
        """Sample a random clip from sequence ``idx`` and return one person's primitives.

        A start frame is drawn (weighted by ``seq_data['frame_weights']`` when
        ``'text'`` is in ``self.weight_scheme``, uniformly otherwise), the clip
        is cut into primitives with stride ``self.future_length``, and the first
        ``self.num_primitive`` of them are stacked per person. Each person is
        canonicalized and featurized independently (no relative
        canonicalization between the two persons here). Texts missing from
        ``self.text_embedding_dict[person]`` are encoded with
        ``self.clip_model`` and cached in that dict.

        Args:
            idx: Index into ``self.dataset``.

        Returns:
            A list of ``self.num_primitive`` dicts, one per primitive, all for
            person1 or all for person2 (each chosen with probability 0.5).
        """
        seq_data = self.dataset[idx]
        num_frames = len(seq_data['motion_p1']['transl'])
        gender = {}
        gender['person1'] = seq_data['motion_p1']['gender']
        gender['person2'] = seq_data['motion_p2']['gender']
        # Pick the clip start; both branches keep the whole seq_length window inside the sequence.
        if 'text' in self.weight_scheme:
            start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
        else:
            start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
        primitive_data_list = []
        for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
            primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
            primitive_data_list.append(primitive_data)

        # Accumulate per-person texts and stack per-person tensors across primitives (dim 0 = primitive index).
        gender_seq_texts = {
            'person1': None,
            'person2': None,
        }
        gender_seq_dict = None
        for primitive_idx in range(self.num_primitive):
            # NOTE(review): texts are later sliced as gender_seq_texts[person][i:i + 1]. That yields one
            # label per primitive only if 'text_p1'/'text_p2' are one-element lists; if get_primitive
            # returns plain strings, '+' concatenates characters instead — confirm against get_primitive.
            primitive_texts_p1 = primitive_data_list[primitive_idx]['text_p1']
            primitive_texts_p2 = primitive_data_list[primitive_idx]['text_p2']
            primitive_dict = {}
            for person in ['person1', 'person2']:
                primitive_dict[person] = {}
                primitive_dict[person]['gender'] = gender[person]
                for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                    primitive_dict[person][key] = primitive_data_list[primitive_idx]['primitive_dict'][person][key]
            gender_seq_texts['person1'] = primitive_texts_p1 if gender_seq_texts['person1'] is None else gender_seq_texts['person1'] + primitive_texts_p1
            gender_seq_texts['person2'] = primitive_texts_p2 if gender_seq_texts['person2'] is None else gender_seq_texts['person2'] + primitive_texts_p2

            if gender_seq_dict is None:
                gender_seq_dict = primitive_dict
            else:
                for person in ['person1', 'person2']:
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)

        canonicalized_primitive_dict = {}
        for person in ['person1', 'person2']:
            gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            # deepcopy, presumably so any in-place changes made by canonicalize do not leak into
            # gender_seq_dict (its 'gender' and 'betas' entries are read again below).
            _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)

        feature_dict = {}
        data_batch = []
        for person in ['person1', 'person2']:
            feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
            # Drop the extra trailing frame (presumably included by get_primitive for delta features).
            feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [num_mp, T, 3]
            feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [num_mp, T, D_pose]; NOTE(review): previously annotated 66 — width follows motion_repr, confirm
            feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [num_mp, T, 22 * 3]

            motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [num_mp, T, D]
            motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [num_mp, D, 1, T]
            # The first history_length frames are flagged in history_mask and copied into
            # history_motion; all later frames stay False / zero.
            history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
            history_mask[..., :self.cfg.history_length] = True
            history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
            history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

            for primitive_idx in range(self.num_primitive):
                start_idx = primitive_idx
                end_idx = primitive_idx + 1
                primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                # Encode and cache any text not yet present in this person's embedding dict.
                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                if len(unseen_texts) > 0:
                    new_text_embeddings = encode_text(self.clip_model, unseen_texts)
                    for idx, text in enumerate(unseen_texts):
                        self.text_embedding_dict[person][text] = new_text_embeddings[idx]
                text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                data_batch.append(
                    {
                        'texts': primitive_texts,
                        'text_embedding': text_embedding,
                        'gender': [gender_seq_dict[person]['gender']],
                        'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                        'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [1, D, 1, T]
                        'history_motion': history_motion[start_idx:end_idx, ...],
                        'history_mask': history_mask[start_idx:end_idx, ...],
                        'history_length': self.cfg.history_length,
                        'future_length': self.cfg.future_length,
                    }
                )
        # data_batch holds person1's num_primitive entries followed by person2's; return one person at random.
        if random.random() < 0.5:
            return data_batch[:self.num_primitive]
        else:
            return data_batch[self.num_primitive:]

# dataset = InterHumanMotionV2(enforce_gender='male',
#                             enforce_zero_beta=1,
#                             device='cuda:3',
#                             mode='merged',
#                             text_encoder='clip',
#                             text_sep=False,
#                             split='test')

# # batch_test = dataset.get_batch(batch_size=2)
# data = dataset.get_item(0)


class InterHumanDataset(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_zero_male',
                 cfg_path='./config_files/config_hydra/motion_primitive/hml_mp_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 **kwargs):
        """Two-person (InterHuman) motion-primitive dataset.

        Loads ``{split}.pkl`` from ``dataset_path``, converts both persons'
        motions (``motion_p1`` / ``motion_p2``) from axis-angle to rotation
        matrices, assigns per-sequence sampling weights, and loads (or computes
        and caches) the feature mean/std and the CLIP text embeddings.

        Args:
            dataset_name: Name tag stored on the instance.
            dataset_path: Directory holding ``{split}.pkl``; the mean/std and
                text-embedding caches are read from / written to it as well.
            cfg_path: YAML config providing ``fps``, ``history_length``,
                ``future_length`` and ``num_primitive``.
            split: Split name used to pick the pickle and cache files.
            device: Torch device for the CLIP model, statistics and embeddings.
            weight_scheme: Sequence sampling scheme; must contain ``'uniform'``
                (equal weights) or ``'length'`` (weight = number of frames).
            prob_static: Probability that ``get_batch`` samples a static
                (single repeated frame) primitive.
            enforce_gender: Both persons of every loaded sequence are asserted
                to have this gender.
            enforce_zero_beta: Asserted truthy; all betas are replaced by zeros.
            load_data: If False, no sequences are loaded (``self.dataset`` is
                not set), so the mean/std and text-embedding caches must exist.
            text_tolerance: Seconds by which the future window is widened on
                each side when matching text labels.
            body_type: Body model type forwarded to ``PrimitiveUtility``.
            seed_only: ``get_batch`` only samples the first primitive starting
                at a labelled segment; requires ``num_primitive == 1``.
            use_frame_weights: Not used by this constructor.
            **kwargs: ``motion_repr`` may override the default feature layout.

        Raises:
            ValueError: If ``weight_scheme`` contains neither ``'uniform'``
                nor ``'length'``.
            AssertionError: If loaded data violates the enforced gender /
                zero-beta assumptions.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        motion_repr = kwargs.get('motion_repr',
                                 {'transl': 3,
                                  'poses_6d': 22 * 6,
                                  'transl_delta': 3,
                                  'global_orient_delta_6d': 6,
                                  'joints': 22 * 3,
                                  'joints_delta': 22 * 3,
                                  })
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # +1 so the last primitive also has the extra frame needed for delta features.
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1

        if load_data:
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                # NOTE(review): pickle.load can execute arbitrary code; only load trusted dataset files.
                dataset = pickle.load(f)
            # Raw pickle motions still use the key 'trans' at this point.
            dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            for data in dataset:
                # Only data matching the enforced gender with zero betas is supported, since the
                # mean/std cache is not gender/beta specific (see the TODO further down).
                assert self.enforce_gender == data['motion_p1']['gender']
                assert self.enforce_gender == data['motion_p2']['gender']
                assert self.enforce_zero_beta
                for motion_key in ('motion_p1', 'motion_p2'):
                    raw_motion = data[motion_key]
                    gender = self.enforce_gender if self.enforce_gender is not None else raw_motion['gender']
                    data[motion_key] = self._convert_raw_motion(raw_motion, gender, self.enforce_zero_beta)
            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            if 'uniform' not in weight_scheme and 'length' not in weight_scheme:
                raise ValueError(f"unsupported weight_scheme {weight_scheme!r}: expected it to contain 'uniform' or 'length'")
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                else:  # 'length'
                    # The motion dicts were rebuilt above, so the frame key is now 'transl' (not 'trans').
                    data['weight'] = len(data['motion_p1']['transl'])
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'
        # TODO: use different mean and std when enforce gender and beta, e.g. suffix the file
        # name with f'_{self.enforce_gender}' and/or '_zero_beta'.
        mean_std_path = Path(dataset_path, f'{file_name}.pkl')
        if mean_std_path.exists():
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        else:
            assert self.split == 'train'
            print('calculating mean and std using train split')
            self.tensor_mean, self.tensor_std = self.calc_mean_std()
            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        self.clip_model = load_and_freeze_clip(clip_version='ViT-B/32', device=self.device)
        self.embedding_path = embedding_path = Path(dataset_path, f'{split}_text_embedding_dict.pkl')
        if embedding_path.exists():
            print(f'loading text embeddings from {embedding_path}')
            with open(embedding_path, 'rb') as f:
                self.text_embedding_dict = pickle.load(f)
        else:
            print('calculating text embeddings')
            raw_texts = []
            for data in self.dataset:
                if 'frame_labels' in data:
                    raw_texts.extend([seg['proc_label'] for seg in data['frame_labels']])
            raw_texts = list(set(raw_texts))
            num_texts = len(raw_texts)
            print('num of unique texts: ', len(raw_texts))
            # get text embeddings in batches of 256 texts
            text_embeddings = [
                encode_text(self.clip_model, raw_texts[batch_start_idx:batch_start_idx + 256])
                for batch_start_idx in range(0, num_texts, 256)
            ]
            if text_embeddings:
                text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()
            else:  # no labelled segments at all; torch.cat([]) would raise
                text_embeddings = np.zeros((0, 512), dtype=np.float32)
            print(text_embeddings.shape)
            self.text_embedding_dict = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
            self.text_embedding_dict[''] = np.zeros(512).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
            with open(embedding_path, 'wb') as f:
                pickle.dump(self.text_embedding_dict, f)
        for key in self.text_embedding_dict:
            self.text_embedding_dict[key] = torch.from_numpy(self.text_embedding_dict[key]).to(dtype=torch.float32, device=self.device)

    @staticmethod
    def _convert_raw_motion(raw_motion, gender, zero_beta):
        """Convert one person's raw pickle motion into the tensor dict used for sampling.

        ``poses`` holds axis-angle values: columns 0:3 are the global
        orientation and columns 3:66 are 21 body joints; both are converted to
        rotation matrices. ``trans`` is renamed to ``transl``.
        """
        betas = torch.from_numpy(raw_motion['betas'].astype(np.float32))
        if zero_beta:
            betas = torch.zeros_like(betas)
        poses = torch.from_numpy(raw_motion['poses'].astype(np.float32))
        return {
            'gender': gender,
            'betas': betas,
            'transl': torch.from_numpy(raw_motion['trans'].astype(np.float32)),
            'global_orient': transforms.axis_angle_to_matrix(poses[:, :3]),  # [T, 3, 3]
            'body_pose': transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3)),  # [T, 21, 3, 3]
            'pelvis_delta': torch.from_numpy(raw_motion['pelvis_delta'].astype(np.float32)),  # [3]
            'joints': torch.from_numpy(raw_motion['joints'].astype(np.float32)),  # [T, 22, 3]
        }

    def calc_mean_std(self, batch_size=512):
        """Compute feature mean/std over every primitive of the loaded split.

        Each sequence is cut into primitives of ``primitive_length + 1`` frames
        with stride ``future_length``. Primitives are processed in chunks of at
        most ``batch_size``. Person1 is canonicalized, and person2 is
        transformed with person1's canonicalization transform
        (``relative_canonicalize``). The two persons' feature tensors are then
        concatenated on the last axis, which is the same layout that
        ``get_batch`` normalizes.

        Args:
            batch_size: Maximum number of primitives canonicalized at once;
                capped at 64 when ``future_length == 1``.

        Returns:
            ``(mean, std)`` tensors of shape [1, 1, 2*D] on ``self.device``.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        persons = ('person1', 'person2')
        tensor_keys = ('betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints')
        all_mp_data = []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['transl'].shape[0]
            # get_primitive reads frames [start, start + primitive_length] inclusive,
            # hence the exclusive upper bound num_frames - primitive_length.
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]
            if not primitive_data_list:  # sequence too short for a single primitive
                continue

            primitive_dict = {}
            for person in persons:
                # Plain gender string, as in get_primitive / get_batch. This used to be a one-element
                # set literal ({gender}), which never compares equal to a gender string.
                primitive_dict[person] = {'gender': primitive_data_list[0]['primitive_dict'][person]['gender']}
                for key in tensor_keys:
                    primitive_dict[person][key] = torch.cat(
                        [data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(primitive_dict[person], self.device)

            # split primitive_dict into batches
            num_primitives = len(primitive_dict['person1']['transl'])
            for batch_start_idx in range(0, num_primitives, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)
                batch_primitive_dict = {}
                for person in persons:
                    batch_primitive_dict[person] = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict[person]['gender'] = primitive_dict[person]['gender']

                # Person2 is expressed with person1's canonicalization transform.
                canonicalized_primitive_dict = {}
                transf_rotmat, transf_transl, canonicalized_primitive_dict['person1'] = self.primitive_utility.canonicalize(batch_primitive_dict['person1'], use_predicted_joints=True)
                canonicalized_primitive_dict['person2'] = self.primitive_utility.relative_canonicalize(batch_primitive_dict['person2'], transf_rotmat, transf_transl)

                motion_tensor = {}
                for person in persons:
                    features = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    # Drop the extra trailing frame that is only needed for delta features.
                    features['transl'] = features['transl'][:, :-1, :]  # [num_primitive, T, 3]
                    features['poses_6d'] = features['poses_6d'][:, :-1, :]  # [num_primitive, T, 22 * 6] assuming the default motion_repr
                    features['joints'] = features['joints'][:, :-1, :]  # [num_primitive, T, 22 * 3]
                    motion_tensor[person] = self.dict_to_tensor(features).detach().cpu()  # [num_primitive, T, D]
                all_mp_data.append(torch.cat((motion_tensor['person1'], motion_tensor['person2']), dim=-1))  # [num_primitive, T, 2*D]

        all_mp_data = torch.cat(all_mp_data, dim=0)  # [N, T, 2*D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)  # [1, 1, 2*D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)  # [1, 1, 2*D]
        return tensor_mean.to(self.device), tensor_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """Cut frames ``start_frame``..``end_frame`` (both inclusive) out of a two-person sequence.

        Callers pass ``end_frame = start_frame + primitive_length``, so the
        slice holds one frame more than the primitive (used for delta
        features); betas are expanded to ``primitive_length + 1`` frames to
        match. Both persons get identity canonicalization transforms.

        Unless ``skip_text`` is set, one label is drawn at random from the
        ``frame_labels`` segments overlapping the primitive's future window
        (widened by ``self.text_tolerance`` seconds on each side); ``''`` is
        used when no segment overlaps.

        Returns:
            ``{'text': <label or ''>, 'primitive_dict': {'person1': {...}, 'person2': {...}}}``
        """
        frames = slice(start_frame, end_frame + 1)  # include one more frame for delta feature calculation
        expanded_len = self.primitive_length + 1

        def cut(motion):
            # Per-person tensors, each with a leading batch axis of size 1.
            return {
                'gender': motion['gender'],
                'betas': motion['betas'].expand(1, expanded_len, 10),
                'transl': motion['transl'][frames].unsqueeze(0),
                'global_orient': motion['global_orient'][frames].unsqueeze(0),
                'body_pose': motion['body_pose'][frames].unsqueeze(0),
                'pelvis_delta': motion['pelvis_delta'].unsqueeze(0),
                'joints': motion['joints'][frames].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }

        primitive_dict = {
            person: cut(seq_data[source])
            for person, source in (('person1', 'motion_p1'), ('person2', 'motion_p2'))
        }

        texts = []
        if not skip_text and 'frame_labels' in seq_data:
            # Future window in seconds, then widened by the tolerance.
            future_start = (start_frame + self.history_length) / self.target_fps
            future_end = (start_frame + self.history_length + self.future_length - 1) / self.target_fps
            window_lo = future_start - self.text_tolerance
            window_hi = future_end + self.text_tolerance
            texts = [
                seg['proc_label']
                for seg in seq_data['frame_labels']
                if have_overlap([seg['start_t'], seg['end_t']], [window_lo, window_hi])
            ]

        return {
            'text': random.choice(texts) if texts else '',
            'primitive_dict': primitive_dict,
        }

    def get_batch(self, batch_size=8):
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)

        for seq_idx in batch_idx:
            seq_data = self.dataset[seq_idx]
            num_frames = len(seq_data['motion_p1']['transl'])
            if self.prob_static > 0 and random.random() < self.prob_static:
                static_frame = random.randint(0, num_frames - 1) # right end inclusive
                motion_data_p1 = seq_data['motion_p1']
                motion_data_p2 = seq_data['motion_p2']
                primitive_length = self.primitive_length
                primitive_dict = {}
                primitive_dict['person1'] = {
                    'gender': motion_data_p1['gender'],
                    'betas': motion_data_p1['betas'].expand(1, primitive_length + 1, 10),
                    'transl': motion_data_p1['transl'][[static_frame]].expand(primitive_length + 1, -1).unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient':
                        motion_data_p1['global_orient'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'body_pose':
                        motion_data_p1['body_pose'][[static_frame]].repeat(primitive_length + 1, 1, 1, 1).unsqueeze(0),
                    'pelvis_delta': motion_data_p1['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p1['joints'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_dict['person2'] = {
                    'gender': motion_data_p2['gender'],
                    'betas': motion_data_p2['betas'].expand(1, primitive_length + 1, 10),
                    'transl': motion_data_p2['transl'][[static_frame]].expand(primitive_length + 1, -1).unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient':
                        motion_data_p2['global_orient'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'body_pose':
                        motion_data_p2['body_pose'][[static_frame]].repeat(primitive_length + 1, 1, 1, 1).unsqueeze(0),
                    'pelvis_delta': motion_data_p2['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p2['joints'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_data = {
                    'text': '',
                    'primitive_dict': primitive_dict
                }
                primitive_data_list = [primitive_data] * self.num_primitive
                # print('get static sequenece')
            elif self.seed_only:  # only take the first primitive for predicting initial seed
                # print('get seed')
                frame_labels = []
                for seg in seq_data['frame_labels']:
                    start_frame = int(seg['start_t'] * self.target_fps)
                    end_frame = start_frame + self.primitive_length
                    if end_frame < num_frames:
                        frame_labels.append((start_frame, end_frame, seg['proc_label']))
                start_frame, end_frame, text = random.choice(frame_labels)

                motion_data_p1 = seq_data['motion_p1']
                motion_data_p2 = seq_data['motion_p2']
                primitive_dict = {}
                primitive_dict['person1'] = {
                    'gender': motion_data_p1['gender'],
                    'betas': motion_data_p1['betas'].expand(1, self.primitive_length + 1, 10),
                    'transl': motion_data_p1['transl'][start_frame:end_frame + 1].unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient': motion_data_p1['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                    'body_pose': motion_data_p1['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                    'pelvis_delta': motion_data_p1['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p1['joints'][start_frame:end_frame + 1].unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_dict['person2'] = {
                    'gender': motion_data_p2['gender'],
                    'betas': motion_data_p2['betas'].expand(1, self.primitive_length + 1, 10),
                    'transl': motion_data_p2['transl'][start_frame:end_frame + 1].unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient': motion_data_p2['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                    'body_pose': motion_data_p2['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                    'pelvis_delta': motion_data_p2['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p2['joints'][start_frame:end_frame + 1].unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }

                primitive_data_list = [
                    {
                        'text': text,
                        'primitive_dict': primitive_dict,
                    }
                ]
            else:
                if 'text' in self.weight_scheme:
                    start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
                else:
                    start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
                primitive_data_list = []
                for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                    primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                    primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        batch = None
        for gender in ['female', 'male']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = []
            
            gender_seq_texts = None
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_texts = [mp_seq[primitive_idx]['text'] for mp_seq in gender_seq_list]
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                gender_seq_texts = primitive_texts if gender_seq_texts is None else gender_seq_texts + primitive_texts
                
                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)

            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            canonicalized_primitive_dict = {}
            transf_rotmat, transf_transl, canonicalized_primitive_dict['person1'] = self.primitive_utility.canonicalize(gender_seq_dict['person1'], use_predicted_joints=True)
            canonicalized_primitive_dict['person2'] = self.primitive_utility.relative_canonicalize(gender_seq_dict['person2'], transf_rotmat, transf_transl)
            # print(f'{gender}:canonicalize time: ', time.time() - self.time)
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                # print(f'{gender}:calc feature time: ', time.time() - self.time)
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]
            motion_tensor_normalized = self.normalize(torch.cat((self.dict_to_tensor(feature_dict['person1']), self.dict_to_tensor(feature_dict['person2'])), dim=-1))  # [B*num_mp, T, 2*D]
            motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)  # [B*num_mp, 2*D, 1, T]
            history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
            history_mask[..., :self.cfg.history_length] = True
            history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
            history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

            for primitive_idx in range(self.num_primitive):
                start_idx = primitive_idx * gender_batch_size
                end_idx = (primitive_idx + 1) * gender_batch_size
                primitive_texts = gender_seq_texts[start_idx:end_idx]
                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict]
                if len(unseen_texts) > 0:
                    self.update_text_embedding_dict(unseen_texts)
                text_embedding = torch.stack([self.text_embedding_dict[text] for text in primitive_texts], dim=0)  # [B, 512]
                gender_batch.append(
                    {
                        'texts': primitive_texts,
                        'text_embedding': text_embedding,
                        'gender_p1': [gender_seq_dict['person1']['gender']] * gender_batch_size,
                        'betas_p1': gender_seq_dict['person1']['betas'][start_idx:end_idx, :-1, :10],
                        'gender_p2': [gender_seq_dict['person2']['gender']] * gender_batch_size,
                        'betas_p2': gender_seq_dict['person2']['betas'][start_idx:end_idx, :-1, :10],
                        'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, 2*D, 1, T]
                        'history_motion': history_motion[start_idx:end_idx, ...],
                        'history_mask': history_mask[start_idx:end_idx, ...],
                        'history_length': self.cfg.history_length,
                        'future_length': self.cfg.future_length,
                    }
                )

            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    for key in ['texts', 'gender_p1', 'gender_p2']:
                        batch[primitive_idx][key] = batch[primitive_idx][key] + gender_batch[primitive_idx][key]
                    for key in ['betas_p1', 'betas_p2', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'text_embedding']:
                        batch[primitive_idx][key] = torch.cat([batch[primitive_idx][key], gender_batch[primitive_idx][key]], dim=0)
            # print(f'{gender} batch time: ', time.time() - self.time)

        return batch
    
    def get_item(self, idx):
        """Return sequence ``idx`` with both persons' data merged.

        Per-person tensors ('betas', 'transl', 'global_orient', 'body_pose',
        'pelvis_delta', 'joints') are concatenated person1-then-person2 along
        the last dimension; 'gender' becomes a two-element list and
        'frame_labels' is passed through unchanged.
        """
        record = self.dataset[idx]
        first, second = record['motion_p1'], record['motion_p2']
        merged = {
            'gender': [first['gender'], second['gender']],
            'frame_labels': record['frame_labels'],
        }
        # assumes time-varying entries are [T, D] per person -> [T, 2*D] after the cat
        merged.update(
            (name, torch.cat((first[name], second[name]), dim=-1))
            for name in ('betas', 'transl', 'global_orient', 'body_pose', 'pelvis_delta', 'joints')
        )
        return merged


class InterHumanDatasetV2(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_zero_male',
                 cfg_path='./config_files/config_hydra/motion_primitive/hml_mp_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='sep', # 'sep' or 'merged'
                 text_encoder = 'clip',
                 text_sep = False,
                 max_segs = 20,
                 **kwargs):
        """Two-person (InterHuman) motion-primitive dataset.

        Loads ``{split}.pkl`` from ``dataset_path``, converts each person's
        parameters to tensors (axis-angle rotations to rotation matrices),
        assigns per-sequence sampling weights, loads or computes the feature
        mean/std, and loads or computes text embeddings for all labels.

        Args:
            dataset_name: name tag stored on the instance.
            dataset_path: directory holding ``{split}.pkl`` and the cached
                mean/std and text-embedding pickles.
            cfg_path: OmegaConf YAML providing fps, history_length,
                future_length and num_primitive.
            split: dataset split name, e.g. "train".
            device: torch device for cached tensors and text-encoder models.
            weight_scheme: 'uniform' -> weight 1.0 per sequence; 'length' ->
                weight proportional to the number of frames.
            prob_static: probability of sampling a static sequence in batching.
            enforce_gender: gender every person in the data must have.
            enforce_zero_beta: must be truthy; betas are replaced with zeros.
            load_data: if False, sequences are not loaded (the mean/std and
                text-embedding caches must then already exist on disk).
            text_tolerance: slack (seconds) when matching labels to a window.
            body_type: body-model type forwarded to PrimitiveUtility.
            seed_only: only the first primitive is used; needs num_primitive == 1.
            use_frame_weights: accepted for interface compatibility; unused here.
            mode: 'sep' or 'merged'.
            text_encoder: 'clip' or 't5'.
            text_sep: compute segment-wise CLIP embeddings plus masks.
            max_segs: maximum number of segments per text when ``text_sep``.
            **kwargs: optional 'motion_repr' overriding the feature layout.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        motion_repr = kwargs.get('motion_repr',
                                 {'transl': 3,
                                  'poses_6d': 22 * 6,
                                  'transl_delta': 3,
                                  'global_orient_delta_6d': 6,
                                  'joints': 22 * 3,
                                  'joints_delta': 22 * 3,
                                  })
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # history + num_primitive future chunks, plus one extra frame
        # (get_primitive slices one more frame for delta features)
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1

        if load_data:
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            for data in dataset:
                # Only a fixed gender and zero betas are supported; presumably because the
                # cached mean/std are not split by gender/betas (see TODO below).
                assert self.enforce_gender == data['motion_p1']['gender'], 'person1 gender must equal enforce_gender'
                assert self.enforce_gender == data['motion_p2']['gender'], 'person2 gender must equal enforce_gender'
                assert self.enforce_zero_beta, 'enforce_zero_beta must be set'
                for person_key in ('motion_p1', 'motion_p2'):
                    motion = data[person_key]
                    gender = self.enforce_gender if self.enforce_gender is not None else motion['gender']
                    betas = torch.from_numpy(motion['betas'].astype(np.float32))
                    if self.enforce_zero_beta:
                        betas = torch.zeros_like(betas)
                    poses = torch.from_numpy(motion['poses'].astype(np.float32))
                    data[person_key] = {
                        'gender': gender,
                        'betas': betas,
                        'transl': torch.from_numpy(motion['trans'].astype(np.float32)),
                        'global_orient': transforms.axis_angle_to_matrix(poses[:, :3]),  # [T, 3, 3]
                        'body_pose': transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3)),  # [T, 21, 3, 3]
                        'pelvis_delta': torch.from_numpy(motion['pelvis_delta'].astype(np.float32)),  # [3]
                        'joints': torch.from_numpy(motion['joints'].astype(np.float32)),  # [T, 22, 3]
                    }
            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    # the motion dicts were rebuilt above: frames live under 'transl', not 'trans'
                    data['weight'] = len(data['motion_p1']['transl'])
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'
        # file_name = f'mean_std_h{self.history_length}_f{self.future_length}_allcanonicalized'
        # TODO: use different mean and std when enforce gender and beta
        # if self.enforce_gender is not None:
        #     file_name = file_name + f'_{self.enforce_gender}'
        # if self.enforce_zero_beta:
        #     file_name = file_name + '_zero_beta'
        mean_std_path = Path(dataset_path, f'{file_name}.pkl')
        if mean_std_path.exists():
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        else:
            assert self.split == 'train'
            print('calculating mean and std using train split')
            self.tensor_mean, self.tensor_std = self.calc_mean_std()
            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        if text_encoder == 'clip':
            self.clip_model = load_and_freeze_clip(clip_version='ViT-B/32', device=self.device)
        else: # text_encoder == 't5'
            self.t5_model = load_and_freeze_t5_encoder(model_name='google/flan-t5-xxl', device=self.device)
        if text_sep:
            self.embedding_path = embedding_path = Path(dataset_path, f'{split}_text_embedding_dict_{text_encoder}_textsep.pkl')
        else:
            self.embedding_path = embedding_path = Path(dataset_path, f'{split}_text_embedding_dict_{text_encoder}.pkl')
        mask_path = Path(str(embedding_path).replace('text_embedding_dict', 'text_mask_dict'))
        if embedding_path.exists():
            print(f'loading text embeddings from {embedding_path}')
            with open(embedding_path, 'rb') as f:
                self.text_embedding_dict = pickle.load(f)
            if text_sep:
                with open(mask_path, 'rb') as f:
                    self.text_mask_dict = pickle.load(f)
        else:
            print('calculating text embeddings')
            raw_texts = []
            for data in self.dataset:
                if 'frame_labels' in data:
                    raw_texts.extend([seg['proc_label'] for seg in data['frame_labels']])
            raw_texts = list(set(raw_texts))
            num_texts = len(raw_texts)
            print('num of unique texts: ', len(raw_texts))
            # get text embeddings by batch
            text_embeddings = []
            text_mask = []
            batch_start_idx = 0
            while batch_start_idx < num_texts:
                batch_end_idx = min(batch_start_idx + 256, num_texts)
                if text_encoder == 'clip':
                    text_embeddings_temp = encode_text(self.clip_model, raw_texts[batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs)
                    if text_sep:
                        text_embeddings.append(text_embeddings_temp[0])
                        text_mask.append(text_embeddings_temp[1])
                    else:
                        text_embeddings.append(text_embeddings_temp)
                else: # text_encoder == 't5'
                    text_embeddings.append(encode_text_t5(self.t5_model[0], self.t5_model[1], raw_texts[batch_start_idx:batch_end_idx]))
                batch_start_idx = batch_end_idx
            text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()
            print(text_embeddings.shape)
            self.text_embedding_dict = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
            # empty text gets a zero embedding (compatible with mdm text masking); it must have
            # the same per-text shape as the real embeddings (differs for t5 / text_sep)
            self.text_embedding_dict[''] = np.zeros(text_embeddings.shape[1:], dtype=np.float32)
            with open(embedding_path, 'wb') as f:
                pickle.dump(self.text_embedding_dict, f)
            if text_sep:
                text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                self.text_mask_dict = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                self.text_mask_dict[''] = np.zeros(max_segs).astype(np.bool_)  # empty text: all segments masked out
                with open(mask_path, 'wb') as f:
                    pickle.dump(self.text_mask_dict, f)
        # move cached numpy embeddings/masks to device tensors
        for key in self.text_embedding_dict:
            self.text_embedding_dict[key] = torch.from_numpy(self.text_embedding_dict[key]).to(dtype=torch.float32, device=self.device)
            if text_sep:
                self.text_mask_dict[key] = torch.from_numpy(self.text_mask_dict[key]).to(dtype=torch.bool, device=self.device)

    def update_text_embedding_dict(self, new_texts, text_sep=False, max_segs=20):
        """Encode ``new_texts`` with CLIP and add them to the embedding caches.

        Args:
            new_texts: list of strings not yet present in ``text_embedding_dict``.
            text_sep: if True, ``encode_text`` returns an ``(embeddings, masks)``
                pair; both ``text_embedding_dict`` and ``text_mask_dict`` are
                updated.
            max_segs: forwarded to ``encode_text``.

        Stored values are cast to float32 (embeddings) / bool (masks) on
        ``self.device``, matching the entries built in ``__init__``.
        """
        encoded = encode_text(self.clip_model, new_texts, text_sep=text_sep, max_segs=max_segs)
        if text_sep:
            # The pair is (batched embeddings, batched masks): index the batch inside
            # each element, not the pair itself.
            embeddings, masks = encoded[0], encoded[1]
        else:
            embeddings, masks = encoded, None
        for idx, text in enumerate(new_texts):
            self.text_embedding_dict[text] = embeddings[idx].to(dtype=torch.float32, device=self.device)
            if masks is not None:
                self.text_mask_dict[text] = masks[idx].to(dtype=torch.bool, device=self.device)

    def calc_mean_std(self, batch_size=512):
        """Compute per-feature mean/std of two-person primitive features.

        Each sequence is cut into primitives with stride ``future_length``.
        Person1 is canonicalized and person2 is transformed with person1's
        canonical transform. Features are computed for both persons and the
        statistics are pooled over primitives, frames and both persons.

        Args:
            batch_size: maximum number of primitives canonicalized at once;
                capped at 64 when ``future_length == 1``.

        Returns:
            ``(tensor_mean, tensor_std)``, each of shape ``[1, 1, D]`` on
            ``self.device``.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        tensor_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']
        persons = ['person1', 'person2']
        all_mp_data = []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['transl'].shape[0]
            # get_primitive reads frames start..end inclusive (primitive_length + 1 frames)
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]

            primitive_dict = {}
            for person in persons:
                # gender is a plain string, as in get_batch (previously a one-element set)
                person_dict = {'gender': primitive_data_list[0]['primitive_dict'][person]['gender']}
                for key in tensor_keys:
                    person_dict[key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(person_dict, self.device)

            # process the primitives of this sequence in chunks of batch_size
            num_primitives = len(primitive_dict['person1']['transl'])
            for batch_start_idx in range(0, num_primitives, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)
                batch_primitive_dict = {}
                for person in persons:
                    batch_primitive_dict[person] = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict[person]['gender'] = primitive_dict[person]['gender']

                canonicalized_primitive_dict = {}
                transf_rotmat, transf_transl, canonicalized_primitive_dict['person1'] = self.primitive_utility.canonicalize(batch_primitive_dict['person1'], use_predicted_joints=True)
                canonicalized_primitive_dict['person2'] = self.primitive_utility.relative_canonicalize(batch_primitive_dict['person2'], transf_rotmat, transf_transl)

                for person in persons:
                    feature_dict = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    # drop the extra trailing frame that was only needed for delta features
                    for key in ('transl', 'poses_6d', 'joints'):
                        feature_dict[key] = feature_dict[key][:, :-1, :]
                    all_mp_data.append(self.dict_to_tensor(feature_dict).detach().cpu())  # [batch, T, D]

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        return tensor_mean.to(self.device), tensor_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """Slice one two-person primitive; ``end_frame`` is included.

        Time-varying tensors cover frames ``start_frame .. end_frame`` (one
        more frame than the primitive proper, kept for delta features) and get
        a leading batch dim of 1. Betas are broadcast to
        ``primitive_length + 1`` frames; identity transforms are attached.

        Args:
            seq_data: record with 'motion_p1' / 'motion_p2' and optionally
                'frame_labels'.
            start_frame: first frame index.
            end_frame: last frame index (inclusive).
            skip_text: if True, no label lookup is done and the text is ''.

        Returns:
            dict with 'text' (a random label overlapping the future window,
            widened by ``text_tolerance``, or '') and 'primitive_dict'.
        """
        window = slice(start_frame, end_frame + 1)

        def slice_person(motion):
            # build one person's primitive from the shared frame window
            return {
                'gender': motion['gender'],
                'betas': motion['betas'].expand(1, self.primitive_length + 1, 10),
                'transl': motion['transl'][window].unsqueeze(0),
                'global_orient': motion['global_orient'][window].unsqueeze(0),
                'body_pose': motion['body_pose'][window].unsqueeze(0),
                'pelvis_delta': motion['pelvis_delta'].unsqueeze(0),
                'joints': motion['joints'][window].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }

        primitive_dict = {
            person: slice_person(seq_data[source])
            for person, source in (('person1', 'motion_p1'), ('person2', 'motion_p2'))
        }

        candidates = []
        if not skip_text and 'frame_labels' in seq_data:
            # future window in seconds, widened on both sides by text_tolerance
            first_future = start_frame + self.history_length
            lower = first_future / self.target_fps - self.text_tolerance
            upper = (first_future + self.future_length - 1) / self.target_fps + self.text_tolerance
            candidates = [
                seg['proc_label']
                for seg in seq_data['frame_labels']
                if have_overlap([seg['start_t'], seg['end_t']], [lower, upper])
            ]

        return {
            'text': random.choice(candidates) if candidates else '',
            'primitive_dict': primitive_dict,
        }

    def get_batch(self, batch_size=8):
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        add_key_list = ['texts', 'gender'] if self.mode=='sep' else ['texts', 'gender_p1', 'gender_p2']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'text_embedding'] if self.mode=='sep' else ['betas_p1', 'betas_p2', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'text_embedding']
        
        for seq_idx in batch_idx:
            seq_data = self.dataset[seq_idx]
            num_frames = len(seq_data['motion_p1']['transl'])
            if self.prob_static > 0 and random.random() < self.prob_static:
                static_frame = random.randint(0, num_frames - 1) # right end inclusive
                motion_data_p1 = seq_data['motion_p1']
                motion_data_p2 = seq_data['motion_p2']
                primitive_length = self.primitive_length
                primitive_dict = {}
                primitive_dict['person1'] = {
                    'gender': motion_data_p1['gender'],
                    'betas': motion_data_p1['betas'].expand(1, primitive_length + 1, 10),
                    'transl': motion_data_p1['transl'][[static_frame]].expand(primitive_length + 1, -1).unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient':
                        motion_data_p1['global_orient'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'body_pose':
                        motion_data_p1['body_pose'][[static_frame]].repeat(primitive_length + 1, 1, 1, 1).unsqueeze(0),
                    'pelvis_delta': motion_data_p1['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p1['joints'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_dict['person2'] = {
                    'gender': motion_data_p2['gender'],
                    'betas': motion_data_p2['betas'].expand(1, primitive_length + 1, 10),
                    'transl': motion_data_p2['transl'][[static_frame]].expand(primitive_length + 1, -1).unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient':
                        motion_data_p2['global_orient'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'body_pose':
                        motion_data_p2['body_pose'][[static_frame]].repeat(primitive_length + 1, 1, 1, 1).unsqueeze(0),
                    'pelvis_delta': motion_data_p2['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p2['joints'][[static_frame]].repeat(primitive_length + 1, 1, 1).unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_data = {
                    'text': '',
                    'primitive_dict': primitive_dict
                }
                primitive_data_list = [primitive_data] * self.num_primitive
                # print('get static sequenece')
            elif self.seed_only:  # only take the first primitive for predicting initial seed
                # print('get seed')
                frame_labels = []
                for seg in seq_data['frame_labels']:
                    start_frame = int(seg['start_t'] * self.target_fps)
                    end_frame = start_frame + self.primitive_length
                    if end_frame < num_frames:
                        frame_labels.append((start_frame, end_frame, seg['proc_label']))
                start_frame, end_frame, text = random.choice(frame_labels)

                motion_data_p1 = seq_data['motion_p1']
                motion_data_p2 = seq_data['motion_p2']
                primitive_dict = {}
                primitive_dict['person1'] = {
                    'gender': motion_data_p1['gender'],
                    'betas': motion_data_p1['betas'].expand(1, self.primitive_length + 1, 10),
                    'transl': motion_data_p1['transl'][start_frame:end_frame + 1].unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient': motion_data_p1['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                    'body_pose': motion_data_p1['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                    'pelvis_delta': motion_data_p1['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p1['joints'][start_frame:end_frame + 1].unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }
                primitive_dict['person2'] = {
                    'gender': motion_data_p2['gender'],
                    'betas': motion_data_p2['betas'].expand(1, self.primitive_length + 1, 10),
                    'transl': motion_data_p2['transl'][start_frame:end_frame + 1].unsqueeze(0),
                    # include one more frame for delta feature calculation
                    'global_orient': motion_data_p2['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                    'body_pose': motion_data_p2['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                    'pelvis_delta': motion_data_p2['pelvis_delta'].unsqueeze(0),
                    'joints': motion_data_p2['joints'][start_frame:end_frame + 1].unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                }

                primitive_data_list = [
                    {
                        'text': text,
                        'primitive_dict': primitive_dict,
                    }
                ]
            else:
                if 'text' in self.weight_scheme:
                    start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
                else:
                    start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
                primitive_data_list = []
                for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                    primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                    primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        batch = None
        for gender in ['female', 'male']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = []
            
            gender_seq_texts = None
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_texts = [mp_seq[primitive_idx]['text'] for mp_seq in gender_seq_list]
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                gender_seq_texts = primitive_texts if gender_seq_texts is None else gender_seq_texts + primitive_texts
                
                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)

            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            canonicalized_primitive_dict = {}
            transf_rotmat, transf_transl, canonicalized_primitive_dict['person1'] = self.primitive_utility.canonicalize(gender_seq_dict['person1'], use_predicted_joints=True)
            canonicalized_primitive_dict['person2'] = self.primitive_utility.relative_canonicalize(gender_seq_dict['person2'], transf_rotmat, transf_transl)
            # print(f'{gender}:canonicalize time: ', time.time() - self.time)
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                # print(f'{gender}:calc feature time: ', time.time() - self.time)
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]
            
            if self.mode == 'sep':
                for person in ['person1', 'person2']:
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        primitive_texts = gender_seq_texts[start_idx:end_idx]
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, text_sep=self.text_sep, max_segs=self.max_segs)
                        text_embedding = torch.stack([self.text_embedding_dict[text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict[text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None
                        gender_batch.append(
                            {
                                'texts': primitive_texts,
                                'text_embedding': text_embedding,
                                'text_mask': text_mask,
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            }
                        )
                
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]
                
                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key], back_group[key] = [], []
                    for d in gender_batch[:self.num_primitive]:
                        front_group[key] += d[key]
                    for d in gender_batch[self.num_primitive:]:
                        back_group[key] += d[key]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                    back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]  
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]  

                selected_batch = []
                for i in range(self.num_primitive):
                    selected_dict = {'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,}
                    for key in front_group.keys():    
                        if key in add_key_list:
                            selected_front = [front_group[key][i] for i in front_indices + i * gender_batch_size] 
                            selected_back = [back_group[key][i] for i in back_indices + i * gender_batch_size]
                            selected_dict[key] = selected_front + selected_back
                        elif key in cat_key_list:
                            selected_front = front_group[key][front_indices + i * gender_batch_size] 
                            selected_back = back_group[key][back_indices + i * gender_batch_size]
                            selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)  
                    selected_batch.append(selected_dict)
                gender_batch = selected_batch
                
            elif self.mode == 'merged':
                # motion_tensor_normalized = self.normalize(torch.cat((self.dict_to_tensor(feature_dict['person1']), self.dict_to_tensor(feature_dict['person2'])), dim=-1))  # [B*num_mp, T, 2*D]
                motion_tensor_normalized = torch.cat((self.normalize(self.dict_to_tensor(feature_dict['person1'])), 
                                                     self.normalize(self.dict_to_tensor(feature_dict['person2']))), dim=-1)
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)  # [B*num_mp, 2*D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts[start_idx:end_idx]
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict[text] for text in primitive_texts], dim=0)  # [B, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict[text] for text in primitive_texts], dim=0)  # [B, max_segs]
                    else:
                        text_mask = None
                    gender_batch.append(
                        {
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'text_mask': text_mask,
                            'gender_p1': [gender_seq_dict['person1']['gender']] * gender_batch_size,
                            'betas_p1': gender_seq_dict['person1']['betas'][start_idx:end_idx, :-1, :10],
                            'gender_p2': [gender_seq_dict['person2']['gender']] * gender_batch_size,
                            'betas_p2': gender_seq_dict['person2']['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, 2*D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        }
                    )

            
            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    for key in add_key_list:
                        batch[primitive_idx][key] = batch[primitive_idx][key] + gender_batch[primitive_idx][key]
                    for key in cat_key_list:
                        batch[primitive_idx][key] = torch.cat([batch[primitive_idx][key], gender_batch[primitive_idx][key]], dim=0)
                    if self.text_sep:
                        batch[primitive_idx]['text_mask'] = torch.cat([batch[primitive_idx]['text_mask'], gender_batch[primitive_idx]['text_mask']], dim=0)
                    else:
                        batch[primitive_idx]['text_mask'] = None

        return batch

# dataset = InterHumanDatasetV2(enforce_gender='male',
#                             enforce_zero_beta=1,
#                             device='cuda:2',
#                             mode='merged',
#                             text_encoder='clip',
#                             text_sep=False,)

# batch_test = dataset.get_batch(batch_size=4)


class InterHumanDatasetV3(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_single_interaction_zero_male_fps20',
                 cfg_path='./config_files/config_hydra/motion_primitive/hml_mp_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='merged', # 'sep' or 'merged'
                 text_sep = False,
                 max_segs = 20,
                 **kwargs):
        """Build the two-person motion-primitive dataset.

        Steps:
          1. Read the primitive config (fps, history/future length, number of primitives).
          2. Optionally load ``{split}.pkl`` and convert each person's motion to tensors.
          3. Assign per-sequence sampling weights.
          4. Load or compute (and cache) feature mean/std.
          5. Load or compute (and cache) CLIP text embeddings for every text channel
             in ``self.key_list``.

        Args:
            dataset_name: Stored on the instance.
            dataset_path: Directory holding ``{split}.pkl``. The mean/std and
                text-embedding cache files are also read from and written to it.
            cfg_path: YAML file providing ``fps``, ``history_length``,
                ``future_length`` and ``num_primitive``.
            split: Split file name. Mean/std may only be computed on ``'train'``.
            device: Device that the statistics and text embeddings are moved to.
            weight_scheme: Substring-matched. ``'uniform'`` gives every sequence
                weight 1.0. ``'length'`` weights a sequence by its frame count.
                Anything else raises ValueError.
            prob_static: Stored on the instance; not used here.
            enforce_gender: Asserted to equal both persons' gender in every
                sequence. Also used as the stored gender.
            enforce_zero_beta: Asserted truthy. Shape parameters are replaced
                by zeros.
            load_data: If False, no sequences are loaded. The mean/std and
                text-embedding caches must then already exist, because computing
                them iterates ``self.dataset``.
            text_tolerance: Slack in seconds used when matching text segments
                to a primitive's future window.
            body_type: Forwarded to ``PrimitiveUtility``.
            seed_only: Requires ``num_primitive == 1``.
            use_frame_weights: Not used here.
            mode: ``'merged'`` adds an ``'interaction'`` text channel next to
                ``'person1'`` and ``'person2'``.
            text_sep: Forwarded to ``encode_text``. Also selects separate cache
                files and enables text masks.
            max_segs: Forwarded to ``encode_text``. Sizes the empty-text entries.
            **kwargs: ``use_interaction_model`` (default False) enables
                interaction-feature statistics.

        Raises:
            ValueError: If ``weight_scheme`` is neither uniform nor length based.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.use_interaction_model = kwargs.get('use_interaction_model', False)
        # Text channels; 'merged' mode also carries an interaction-level description.
        self.key_list = ['person1', 'person2', 'interaction'] if self.mode=='merged' else ['person1', 'person2']
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        motion_repr = {'transl': 3,
                       'poses_6d': 22 * 6,
                       'transl_delta': 3,
                       'global_orient_delta_6d': 6,
                       'joints': 22 * 3,
                       'joints_delta': 22 * 3,
                       }
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # The +1 matches the extra trailing frame get_primitive slices for delta features.
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1

        if load_data:
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            # Sequences explicitly excluded from the dataset.
            elements_to_remove = {'7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385'}
            dataset = [data for data in dataset if data['seq_name'] not in elements_to_remove]

            for data in dataset:
                assert self.enforce_gender == data['motion_p1']['gender']
                assert self.enforce_gender == data['motion_p2']['gender']
                assert self.enforce_zero_beta
                # Replace each raw numpy motion dict with a tensor dict
                # (axis-angle poses -> rotation matrices, 'trans' -> 'transl').
                for motion_key in ('motion_p1', 'motion_p2'):
                    raw = data[motion_key]
                    betas = torch.from_numpy(raw['betas'].astype(np.float32))
                    if self.enforce_zero_beta:
                        betas = torch.zeros_like(betas)
                    poses = torch.from_numpy(raw['poses'].astype(np.float32))
                    data[motion_key] = {
                        'gender': self.enforce_gender if self.enforce_gender is not None else raw['gender'],
                        'betas': betas,
                        'transl': torch.from_numpy(raw['trans'].astype(np.float32)),
                        'global_orient': transforms.axis_angle_to_matrix(poses[:, :3]),  # [T, 3, 3]
                        'body_pose': transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3)),  # [T, 21, 3, 3]
                        'pelvis_delta': torch.from_numpy(raw['pelvis_delta'].astype(np.float32)),  # [3]
                        'joints': torch.from_numpy(raw['joints'].astype(np.float32)),  # [T, 22, 3]
                    }
            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    # The motion dicts were rebuilt above, so the frames live under 'transl'.
                    data['weight'] = len(data['motion_p1']['transl'])
                else:
                    raise ValueError(f'unsupported weight_scheme: {weight_scheme}')
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'
        mean_std_path = Path(dataset_path, f'{file_name}.pkl')
        mean_std_interaction_path = Path(dataset_path, f'{file_name}_interaction.pkl')
        # Interaction statistics stay None unless use_interaction_model is enabled,
        # on both the load and the compute path.
        self.rel_mean, self.rel_std = None, None
        if mean_std_path.exists() and (not self.use_interaction_model or mean_std_interaction_path.exists()):
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]

            if self.use_interaction_model:
                print(f'loading interaction mean and std from {mean_std_interaction_path}')
                with open(mean_std_interaction_path, 'rb') as f:
                    self.rel_mean, self.rel_std = pickle.load(f)  # [1, 1, D]
        else:
            assert self.split == 'train'
            print('calculating mean and std using train split')
            result = self.calc_mean_std()

            if self.use_interaction_model:
                self.tensor_mean, self.tensor_std, self.rel_mean, self.rel_std = result
            else:
                self.tensor_mean, self.tensor_std = result

            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)

            if self.use_interaction_model:
                with open(mean_std_interaction_path, 'wb') as f:
                    pickle.dump((self.rel_mean.detach().cpu(), self.rel_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        self.clip_model = load_and_freeze_clip(clip_version='ViT-B/32', device=self.device)
        suffix = '_textsep' if text_sep else ''
        self.embedding_path = {
            key_type: Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{suffix}.pkl')
            for key_type in self.key_list
        }
        # text_embedding_dict[key_type][text] -> embedding (and text_mask_dict likewise when text_sep)
        self.text_embedding_dict = {}
        if text_sep:
            self.text_mask_dict = {}

        for key_type in self.key_list:
            embedding_file = self.embedding_path[key_type]
            mask_file = Path(str(embedding_file).replace('text_embedding_dict', 'text_mask_dict'))
            if embedding_file.exists():
                print(f"Loading text_{key_type} embeddings from {embedding_file}!")
                with open(embedding_file, 'rb') as f:
                    self.text_embedding_dict[key_type] = pickle.load(f)
                if text_sep:
                    with open(mask_file, 'rb') as f:
                        self.text_mask_dict[key_type] = pickle.load(f)
            else:
                print('Calculating text embeddings')
                raw_texts = []
                for data in self.dataset:
                    if f'frame_labels_{key_type}' in data:
                        raw_texts.extend([seg['proc_label'] for seg in data['frame_labels_' + key_type]])

                raw_texts = list(set(raw_texts))
                num_texts = len(raw_texts)
                print(f'num of unique texts_{key_type}: ', len(raw_texts))

                # get text embeddings in batches of 256
                text_embeddings = []
                text_mask = []
                for batch_start_idx in range(0, num_texts, 256):
                    batch_end_idx = min(batch_start_idx + 256, num_texts)
                    encoded = encode_text(self.clip_model, raw_texts[batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs)
                    if text_sep:
                        # with text_sep, encode_text returns an (embeddings, mask) pair
                        text_embeddings.append(encoded[0])
                        text_mask.append(encoded[1])
                    else:
                        text_embeddings.append(encoded)
                # torch.cat rejects an empty list (a channel without any labels)
                if text_embeddings:
                    text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()

                self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                # Empty text gets a zero embedding (compatible with mdm text masking).
                # It must live inside this channel's dict so it is looked up and cached.
                if text_sep:
                    self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, 512)).astype(np.float32)
                else:
                    self.text_embedding_dict[key_type][''] = np.zeros(512).astype(np.float32)
                with open(embedding_file, 'wb') as f:
                    pickle.dump(self.text_embedding_dict[key_type], f)
                if text_sep:
                    if text_mask:
                        text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                    self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                    self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)  # empty text: all segments masked out
                    with open(mask_file, 'wb') as f:
                        pickle.dump(self.text_mask_dict[key_type], f)

            # Convert the cached numpy arrays to tensors on the target device.
            for key in self.text_embedding_dict[key_type]:
                self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device=self.device)
                if text_sep:
                    self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device=self.device)

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        """Encode texts missing from the cache and add them to one text channel.

        Args:
            new_texts: List of strings to encode with ``self.clip_model``.
            key_type: Text channel to update (one of ``self.key_list``).
            text_sep: If True, ``encode_text`` returns an ``(embeddings, mask)``
                pair. The masks are then stored in ``self.text_mask_dict[key_type]``.
            max_segs: Forwarded to ``encode_text``.
        """
        encoded = encode_text(self.clip_model, new_texts, text_sep=text_sep, max_segs=max_segs)
        if text_sep:
            # The pair is batched over new_texts. It must be unpacked first,
            # not indexed by text position.
            new_embeddings, new_masks = encoded
        else:
            new_embeddings, new_masks = encoded, None
        for idx, text in enumerate(new_texts):
            # Use the dtype/device of the entries prepared in __init__ so they
            # can later be stacked together.
            self.text_embedding_dict[key_type][text] = new_embeddings[idx].detach().to(dtype=torch.float32, device=self.device)
            if text_sep:
                self.text_mask_dict[key_type][text] = new_masks[idx].detach().to(dtype=torch.bool, device=self.device)

    def calc_mean_std(self, batch_size=512):
        """Compute feature normalization statistics over all primitives in ``self.dataset``.

        How it works:
          - Each sequence is cut into primitives with stride ``future_length``.
          - Each person is canonicalized independently and converted to
            features by ``self.primitive_utility``.
          - Mean and std are taken over primitives and frames, with both
            persons pooled together.

        Args:
            batch_size: Maximum number of primitives canonicalized at once.
                Capped at 64 when ``future_length == 1``.

        Returns:
            ``(tensor_mean, tensor_std)``, each [1, 1, D] on ``self.device``.
            When ``use_interaction_model`` is set, ``(rel_mean, rel_std)`` over
            relative orientation / translation / min-joint-distance features
            are appended.

        Raises:
            RuntimeError: If ``use_interaction_model`` is set but no interaction
                features were collected.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        tensor_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']
        all_mp_data = []
        all_rel_info = []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['transl'].shape[0]
            # get_primitive slices [start, start + primitive_length] inclusive,
            # so every start must be < num_frames - primitive_length.
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]

            primitive_dict = {}
            for person in ['person1', 'person2']:
                # Gender is a plain string, matching get_primitive's output.
                person_dict = {'gender': primitive_data_list[0]['primitive_dict'][person]['gender']}
                for key in tensor_keys:
                    person_dict[key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(person_dict, self.device)

            # split primitive_dict into batches
            num_primitives = len(primitive_dict['person1']['transl'])
            for batch_start_idx in range(0, num_primitives, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)
                batch_primitive_dict = {}
                canonicalized_primitive_dict = {}
                for person in ['person1', 'person2']:
                    batch_primitive_dict[person] = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict[person]['gender'] = primitive_dict[person]['gender']
                    # Canonicalize a copy so batch_primitive_dict, which the interaction
                    # features below use, stays as-is (presumably canonicalize edits its input).
                    _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(batch_primitive_dict[person]), use_predicted_joints=True)

                for person in ['person1', 'person2']:
                    feature_dict = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    # drop the extra trailing frame that was only needed for delta features
                    feature_dict['transl'] = feature_dict['transl'][:, :-1, :]          # [num_primitive, T, 3]
                    feature_dict['poses_6d'] = feature_dict['poses_6d'][:, :-1, :]      # [num_primitive, T, 66]
                    feature_dict['joints'] = feature_dict['joints'][:, :-1, :]          # [num_primitive, T, 22 * 3]
                    all_mp_data.append(self.dict_to_tensor(feature_dict).detach().cpu())  # [num_primitive, T, D]

                if self.use_interaction_model:
                    joints_p1 = batch_primitive_dict['person1']['joints'].reshape(-1, batch_primitive_dict['person1']['joints'].shape[1], 22, 3)  # [B, T+1, 22, 3]
                    joints_p2 = batch_primitive_dict['person2']['joints'].reshape(-1, batch_primitive_dict['person2']['joints'].shape[1], 22, 3)
                    global_orient_p1 = batch_primitive_dict['person1']['global_orient']  # [B, T, 3, 3]
                    global_orient_p2 = batch_primitive_dict['person2']['global_orient']
                    transl_p1 = batch_primitive_dict['person1']['transl']  # [B, T, 3]
                    transl_p2 = batch_primitive_dict['person2']['transl']

                    # each person's orientation / root offset expressed in the other's frame
                    rel_global_orient = {
                        'b2a': global_orient_p1.transpose(-1, -2) @ global_orient_p2,
                        'a2b': global_orient_p2.transpose(-1, -2) @ global_orient_p1,
                    }
                    rel_root_transl = {
                        'b2a': torch.matmul(global_orient_p1.transpose(-1, -2), (transl_p2 - transl_p1).unsqueeze(-1)).squeeze(-1),
                        'a2b': torch.matmul(global_orient_p2.transpose(-1, -2), (transl_p1 - transl_p2).unsqueeze(-1)).squeeze(-1),
                    }

                    dists = torch.norm(joints_p1.unsqueeze(3) - joints_p2.unsqueeze(2), dim=-1)  # [B, T+1, 22, 22]
                    rel_mindis = {
                        'b2a': dists.min(dim=-1).values,  # [B, T+1, 22]
                        'a2b': dists.min(dim=-2).values,
                    }

                    for key in ['b2a', 'a2b']:
                        rot_6d = transforms.matrix_to_rotation_6d(rel_global_orient[key])  # [B, T+1, 6]
                        rel_info = torch.cat([rot_6d, rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B, T+1, 6+3+22]
                        rel_info = rel_info[:, self.cfg.history_length:-1, :]  # [B, T-(1+history_length), D]
                        all_rel_info.append(rel_info.detach().cpu())

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        if not self.use_interaction_model:
            return tensor_mean.to(self.device), tensor_std.to(self.device)

        if not all_rel_info:
            raise RuntimeError('no interaction features collected; cannot compute rel_mean/rel_std')
        all_rel_info = torch.cat(all_rel_info, dim=0)  # [2N, (T-(1+history_length)), D]
        rel_mean = all_rel_info.mean(dim=[0, 1], keepdim=True)
        rel_std = all_rel_info.std(dim=[0, 1], keepdim=True)
        return tensor_mean.to(self.device), tensor_std.to(self.device), rel_mean.to(self.device), rel_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """end_frame included

        Slice frames ``start_frame .. end_frame`` (both inclusive) for both persons.
        Each slice has ``primitive_length + 1`` frames; the extra frame is used
        for delta features.

        Returns:
            A dict with:
              - ``'text_<key_type>'`` for each key type in ``self.key_list``: a
                random label whose time span overlaps the primitive's future
                window (widened by ``text_tolerance``), or ``''`` if there is
                none or ``skip_text`` is set.
              - ``'primitive_dict'``: per-person tensors for ``'person1'`` and
                ``'person2'``, each with a leading batch axis of 1.
        """
        frames = slice(start_frame, end_frame + 1)  # include one more frame for delta feature calculation

        def build_person(motion):
            return {
                'gender': motion['gender'],
                'betas': motion['betas'].expand(1, self.primitive_length + 1, 10),
                'transl': motion['transl'][frames].unsqueeze(0),
                'global_orient': motion['global_orient'][frames].unsqueeze(0),
                'body_pose': motion['body_pose'][frames].unsqueeze(0),
                'pelvis_delta': motion['pelvis_delta'].unsqueeze(0),
                'joints': motion['joints'][frames].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),  # identity transform before canonicalization
                'transf_transl': torch.zeros(1, 1, 3),
            }

        primitive_dict = {
            'person1': build_person(seq_data['motion_p1']),
            'person2': build_person(seq_data['motion_p2']),
        }

        output = {}
        for key_type in self.key_list:
            label_key = f'frame_labels_{key_type}'
            candidates = []
            if not skip_text and label_key in seq_data:
                # future window of this primitive, in seconds
                window_start = (start_frame + self.history_length) / self.target_fps
                window_end = (start_frame + self.history_length + self.future_length - 1) / self.target_fps
                query = [window_start - self.text_tolerance, window_end + self.text_tolerance]
                candidates = [
                    seg['proc_label']
                    for seg in seq_data[label_key]
                    if have_overlap([seg['start_t'], seg['end_t']], query)
                ]
            output['text_' + key_type] = random.choice(candidates) if candidates else ''
        output['primitive_dict'] = primitive_dict
        return output

    def get_rel_mean_std_by_device(self, device):
        """Return ``(rel_mean, rel_std)`` moved to ``device``, memoized per device.

        The copies are stored in ``self.rel_mean_device_dict`` (created lazily), so
        repeated calls with the same device return the same tensor objects.

        Raises:
            AssertionError: if ``self.rel_mean`` or ``self.rel_std`` is None when a
                new device has to be populated.
        """
        if not hasattr(self, 'rel_mean_device_dict'):
            self.rel_mean_device_dict = {}
        cache = self.rel_mean_device_dict

        if device in cache:
            return cache[device]

        assert self.rel_mean is not None and self.rel_std is not None, "rel_mean/std must be computed before normalization."
        mean_on_device = self.rel_mean.to(device=device)
        std_on_device = self.rel_std.to(device=device)
        cache[device] = (mean_on_device, std_on_device)
        return cache[device]

    def normalize_rel_info(self, rel_info: torch.Tensor) -> torch.Tensor:
        """Standardize interaction features as ``(rel_info - rel_mean) / rel_std``.

        The statistics are fetched on ``rel_info.device`` through
        ``get_rel_mean_std_by_device`` and broadcast against ``rel_info``
        (documented callers pass [B, D] or [B, T, D] tensors).
        """
        mean, std = self.get_rel_mean_std_by_device(rel_info.device)
        centered = rel_info - mean
        return centered / std

    def get_batch(self, batch_size=8):
        """Sample a training batch of ``num_primitive`` consecutive motion primitives.

        Sequence indices come from ``self.get_batch_idx``. For each sequence a start
        frame is drawn (weighted by ``seq_data['frame_weights']`` when
        ``weight_scheme`` contains ``'text'``, uniform otherwise) and the window is
        split into primitives with ``get_primitive``. Samples are grouped by
        person1's gender ('female' first, then 'male'); each group is canonicalized,
        converted to normalized feature tensors, and the groups are concatenated
        along the batch dimension.

        Args:
            batch_size: number of sequences to sample.

        Returns:
            mode == 'merged': dict keyed by ``self.key_list`` ('person1', 'person2',
                'interaction'), each value a list of ``num_primitive`` dicts.
            mode == 'sep': list of ``num_primitive`` dicts whose samples are drawn
                from person1 or person2 according to a random selector.
            None when no sampled sequence has gender 'female' or 'male'.
        """
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        # Per-sample Python lists, concatenated across gender groups with ``+``.
        add_key_list = ['texts', 'gender']
        # Tensors with a leading batch dimension, concatenated with ``torch.cat``.
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'text_embedding']
        if self.mode == 'merged':
            # Only merged-mode per-person dicts carry the canonicalization transform.
            # Sep-mode dicts have no such keys, so listing them unconditionally made
            # the sep branch fail with a KeyError.
            cat_key_list += ['transf_rotmat', 'transf_transl']
        if self.text_sep:
            cat_key_list.append('text_mask')
        # Tensor entries of the merged-mode 'interaction' dicts.
        interaction_cat_keys = ['text_embedding', 'text_mask', 'rel_pose_b2a', 'rel_pose_a2b', 'rel_info_b2a', 'rel_info_a2b']

        def _append_samples(dst, src, list_keys, tensor_keys):
            """Append the samples of primitive dict ``src`` onto ``dst`` in place."""
            for key in list_keys:
                dst[key] = dst[key] + src[key]
            for key in tensor_keys:
                # Skip optional entries: 'text_mask' is None when text_sep is off and
                # the 'rel_info_*' keys exist only when use_interaction_model is on.
                if dst.get(key) is None:
                    continue
                dst[key] = torch.cat([dst[key], src[key]], dim=0)

        for seq_idx in batch_idx:
            seq_data = self.dataset[seq_idx]
            num_frames = len(seq_data['motion_p1']['transl'])
            if 'text' in self.weight_scheme:
                start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
            else:
                start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
            primitive_data_list = []
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        # NOTE: grouping uses person1's gender only; person2 gets the same label below.
        batch = None
        for gender in ['female', 'male']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = {} if self.mode == 'merged' else []

            # Stack all primitives primitive-major: rows [p*B, (p+1)*B) hold primitive p.
            gender_seq_texts = {key_type: None for key_type in self.key_list}
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                primitive_texts = {}
                for key_type in self.key_list:
                    primitive_texts[key_type] = [mp_seq[primitive_idx]['text_'+key_type] for mp_seq in gender_seq_list]
                    gender_seq_texts[key_type] = primitive_texts[key_type] if gender_seq_texts[key_type] is None else gender_seq_texts[key_type] + primitive_texts[key_type]

                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)

            canonicalized_primitive_dict = {}
            if self.mode == 'merged':
                transf_rotmat, transf_transl = {}, {}
            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
                if self.mode == 'merged':
                    transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)
                else:
                    _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)

            if self.mode == 'merged':
                # rel_pose: relative canonicalization transform between the two persons
                rel_rotmat, rel_transl, rel_pose = {}, {}, {}
                rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
                rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
                for rel in ['b2a', 'a2b']:
                    rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                    rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]

                if self.use_interaction_model:
                    # relative root orientation / translation and per-joint min distance
                    rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                    rel_global_orient['b2a'] = gender_seq_dict['person1']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person2']['global_orient']
                    rel_global_orient['a2b'] = gender_seq_dict['person2']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person1']['global_orient']
                    rel_root_transl['b2a'] = torch.matmul(gender_seq_dict['person1']['global_orient'].transpose(-1, -2), (gender_seq_dict['person2']['transl']-gender_seq_dict['person1']['transl']).unsqueeze(-1)).squeeze(-1)
                    rel_root_transl['a2b'] = torch.matmul(gender_seq_dict['person2']['global_orient'].transpose(-1, -2), (gender_seq_dict['person1']['transl']-gender_seq_dict['person2']['transl']).unsqueeze(-1)).squeeze(-1)

                    dists = torch.norm(gender_seq_dict['person1']['joints'].unsqueeze(3)-gender_seq_dict['person2']['joints'].unsqueeze(2), dim=-1)
                    rel_mindis['b2a'], _ = dists.min(dim=-1)
                    rel_mindis['a2b'], _ = dists.min(dim=-2)

                    rel_info = {}
                    for key in ['b2a', 'a2b']:
                        rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B*num_mp, T, 6+3+22]

            # calc features; the extra delta frame is dropped afterwards
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]

            if self.mode == 'merged':
                for person in ['person1', 'person2']:
                    gender_batch[person] = []
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                        text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None

                        gender_batch[person].append({
                                'texts': primitive_texts,
                                'text_embedding': text_embedding,
                                'text_mask': text_mask,
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            })
                gender_batch['interaction'] = []
                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                    else:
                        text_mask = None
                    gender_batch['interaction'].append({
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'text_mask': text_mask,
                            'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                            'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                        })
                    if self.use_interaction_model:
                        gender_batch['interaction'][-1].update({
                            'rel_info_b2a': self.normalize_rel_info(rel_info['b2a'][start_idx:end_idx, self.cfg.history_length:-1]),
                            'rel_info_a2b': self.normalize_rel_info(rel_info['a2b'][start_idx:end_idx, self.cfg.history_length:-1]),
                        })

            elif self.mode == 'sep':
                for person in ['person1', 'person2']:
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                        text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None
                        gender_batch.append(
                            {
                                'texts': primitive_texts,
                                'text_embedding': text_embedding,
                                'text_mask': text_mask,
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            }
                        )
                # gender_batch[:num_primitive] holds person1, gender_batch[num_primitive:] person2.
                # NOTE(review): the selector is a random permutation of B ones and B zeros
                # over 2B (sequence, person) slots, so one sequence may contribute both
                # persons (or neither); confirm this sampling is intended.
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]

                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key], back_group[key] = [], []
                    for d in gender_batch[:self.num_primitive]:
                        front_group[key] += d[key]
                    for d in gender_batch[self.num_primitive:]:
                        back_group[key] += d[key]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                    back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]

                selected_batch = []
                for i in range(self.num_primitive):
                    selected_dict = {'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,}
                    for key in front_group.keys():
                        if key in add_key_list:
                            selected_front = [front_group[key][j] for j in front_indices + i * gender_batch_size]
                            selected_back = [back_group[key][j] for j in back_indices + i * gender_batch_size]
                            selected_dict[key] = selected_front + selected_back
                        elif key in cat_key_list:
                            selected_front = front_group[key][front_indices + i * gender_batch_size]
                            selected_back = back_group[key][back_indices + i * gender_batch_size]
                            selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)
                    selected_batch.append(selected_dict)
                gender_batch = selected_batch

            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    if self.mode == 'merged':
                        for key_type in self.key_list:
                            if key_type != 'interaction':
                                _append_samples(batch[key_type][primitive_idx], gender_batch[key_type][primitive_idx], add_key_list, cat_key_list)
                            else:
                                _append_samples(batch[key_type][primitive_idx], gender_batch[key_type][primitive_idx], ['texts'], interaction_cat_keys)
                    else:
                        _append_samples(batch[primitive_idx], gender_batch[primitive_idx], add_key_list, cat_key_list)
        return batch
    
    def get_item(self, idx):
        """Build one sample: every primitive of a random window of sequence ``idx``.

        The start frame is drawn as in ``get_batch`` (weighted by
        ``seq_data['frame_weights']`` when ``weight_scheme`` contains ``'text'``,
        uniform otherwise). The window is split into ``num_primitive`` primitives,
        each person is canonicalized, and features are normalized. Every
        per-primitive entry has a batch dimension of 1.

        Args:
            idx: index into ``self.dataset``.

        Returns:
            mode == 'merged': dict with 'person1', 'person2' and 'interaction'
                entries, each a list of ``num_primitive`` per-primitive dicts.
            otherwise: list of ``num_primitive`` per-primitive dicts for a single
                person, person1 or person2 with equal probability.
        """
        seq_data = self.dataset[idx]
        num_frames = len(seq_data['motion_p1']['transl'])
        gender = {}
        gender['person1'] = seq_data['motion_p1']['gender']
        gender['person2'] = seq_data['motion_p2']['gender']
        if 'text' in self.weight_scheme:
            start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
        else:
            start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
        primitive_data_list = []
        for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
            primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
            primitive_data_list.append(primitive_data)

        # Stack primitives along dim 0: row p holds primitive p.
        gender_seq_texts = {key_type: [] for key_type in self.key_list}
        gender_seq_dict = None
        for primitive_idx in range(self.num_primitive):
            primitive_dict = {}
            for person in ['person1', 'person2']:
                primitive_dict[person] = {}
                primitive_dict[person]['gender'] = gender[person]
                for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                    primitive_dict[person][key] = primitive_data_list[primitive_idx]['primitive_dict'][person][key]
            primitive_texts = {}
            for key_type in self.key_list:
                primitive_texts[key_type] = primitive_data_list[primitive_idx]['text_'+key_type]
                gender_seq_texts[key_type].append(primitive_texts[key_type])

            if gender_seq_dict is None:
                gender_seq_dict = primitive_dict
            else:
                for person in ['person1', 'person2']:
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)

        canonicalized_primitive_dict = {}
        if self.mode == 'merged':
            transf_rotmat, transf_transl = {}, {}
        for person in ['person1', 'person2']:
            gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            if self.mode == 'merged':
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)
            else:
                _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)

        if self.mode == 'merged':
            # rel_pose: relative canonicalization transform between the two persons
            rel_rotmat, rel_transl, rel_pose = {}, {}, {}
            rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
            for rel in ['b2a', 'a2b']:
                rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [num_mp, 6+3]

            if self.use_interaction_model:
                # relative root orientation / translation and per-joint min distance
                rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                rel_global_orient['b2a'] = gender_seq_dict['person1']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person2']['global_orient']
                rel_global_orient['a2b'] = gender_seq_dict['person2']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person1']['global_orient']
                rel_root_transl['b2a'] = torch.matmul(gender_seq_dict['person1']['global_orient'].transpose(-1, -2), (gender_seq_dict['person2']['transl']-gender_seq_dict['person1']['transl']).unsqueeze(-1)).squeeze(-1)
                rel_root_transl['a2b'] = torch.matmul(gender_seq_dict['person2']['global_orient'].transpose(-1, -2), (gender_seq_dict['person1']['transl']-gender_seq_dict['person2']['transl']).unsqueeze(-1)).squeeze(-1)

                dists = torch.norm(gender_seq_dict['person1']['joints'].unsqueeze(3)-gender_seq_dict['person2']['joints'].unsqueeze(2), dim=-1)
                rel_mindis['b2a'], _ = dists.min(dim=-1)
                rel_mindis['a2b'], _ = dists.min(dim=-2)

                rel_info = {}
                for key in ['b2a', 'a2b']:
                    rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [num_mp, T, 6+3+22]

        # calc features; the extra delta frame is dropped afterwards
        feature_dict = {}
        data_batch = {} if self.mode == 'merged' else []
        for person in ['person1', 'person2']:
            feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
            feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [num_mp, T, 3]
            feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [num_mp, T, 66]
            feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [num_mp, T, 22 * 3]

        if self.mode == 'merged':
            for person in ['person1', 'person2']:
                data_batch[person] = []
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [num_mp, D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx
                    end_idx = primitive_idx + 1
                    primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [1, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [1, max_segs]
                    else:
                        text_mask = None

                    data_batch[person].append({
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'text_mask': text_mask,
                            'gender': [gender_seq_dict[person]['gender']],
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [1, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                            'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        })
            data_batch['interaction'] = []
            for primitive_idx in range(self.num_primitive):
                start_idx = primitive_idx
                end_idx = (primitive_idx + 1)
                primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                if len(unseen_texts) > 0:
                    self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [1, 512]
                if self.text_sep:
                    text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [1, max_segs]
                else:
                    text_mask = None
                data_batch['interaction'].append({
                        'texts': primitive_texts,
                        'text_embedding': text_embedding,
                        'text_mask': text_mask,
                        'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                        'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                    })
                if self.use_interaction_model:
                    # NOTE(review): unlike get_batch, rel_info here is not passed through
                    # normalize_rel_info and is flattened to [1, T*D] instead of kept as
                    # [1, T, D]; confirm the consumer of get_item expects this.
                    data_batch['interaction'][-1].update({
                        'rel_info_b2a': rel_info['b2a'][start_idx:end_idx, self.cfg.history_length:-1].reshape(1, -1),
                        'rel_info_a2b': rel_info['a2b'][start_idx:end_idx, self.cfg.history_length:-1].reshape(1, -1),
                    })
            return data_batch
        else:
            for person in ['person1', 'person2']:
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [num_mp, D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx
                    end_idx = primitive_idx + 1
                    primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                    if len(unseen_texts) > 0:
                        # Use the same text-cache update as every other branch. The former
                        # direct encode_text(self.clip_model, ...) call bypassed it and never
                        # filled text_mask_dict, which the text_sep path below reads.
                        self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [1, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [1, max_segs]
                    else:
                        text_mask = None
                    data_batch.append(
                        {
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'text_mask': text_mask,
                            'gender': [gender_seq_dict[person]['gender']],
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [1, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        }
                    )
            # data_batch[:num_primitive] is person1, data_batch[num_primitive:] is person2.
            if random.random() < 0.5:
                return data_batch[:self.num_primitive]
            else:
                return data_batch[self.num_primitive:]

# dataset = InterHumanDatasetV3(enforce_gender='male',
#                             enforce_zero_beta=1,
#                             device='cuda:7',
#                             mode='merged',
#                             text_encoder='clip',
#                             text_sep=False,
#                             split='train',
#                             use_interaction_model=True,)
# # dataset.calc_mean_std()

# batch_test = dataset.get_batch(batch_size=8)
# sample = dataset.get_item(0)


class InterHumanDatasetV4(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_single_interaction_zero_male_fps20',
                 cfg_path='./config_files/config_hydra/motion_primitive/hml_mp_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='merged', # 'sep' or 'merged'
                 text_sep = False,
                 max_segs = 20,
                 **kwargs):
        """Two-person InterHuman motion-primitive dataset.

        Loads ``{split}.pkl`` from ``dataset_path``, optionally pads short
        sequences, converts each person's motion to float32 tensors, assigns
        per-sequence sampling weights, loads (or computes and caches) feature
        mean/std, and loads (or computes and caches) CLIP text embeddings for
        every text key type.

        Args:
            dataset_name: Stored on the instance.
            dataset_path: Directory holding ``{split}.pkl`` and the cached
                mean/std and text-embedding pickles.
            cfg_path: OmegaConf YAML providing ``fps``, ``history_length``,
                ``future_length`` and ``num_primitive``.
            split: Split name; computing mean/std is only allowed for ``'train'``.
            device: Device for the primitive utility, CLIP and cached embeddings.
            weight_scheme: Must contain ``'uniform'`` (weight 1 per sequence) or
                ``'length'`` (weight = number of frames).
            prob_static, text_tolerance, seed_only: Stored for sampling;
                ``seed_only`` requires ``num_primitive == 1``.
            enforce_gender: If not None, overrides both persons' gender.
            enforce_zero_beta: If truthy, betas are replaced with zeros.
            load_data: If False the pickle is not read and ``self.dataset`` is
                not set, so cached mean/std and embeddings must already exist.
            body_type: Forwarded to ``PrimitiveUtility``.
            use_frame_weights: Accepted for interface compatibility; not used here.
            mode: ``'merged'`` adds an ``'interaction'`` text key to
                ``'person1'``/``'person2'``; ``'sep'`` does not.
            text_sep: Encode texts per segment (with a mask) instead of as one vector.
            max_segs: Maximum number of segments per text when ``text_sep`` is set.
            **kwargs: ``padding`` (bool, pad short sequences with their last
                frame) and ``use_interaction_model`` (bool, also compute
                interaction-feature statistics).

        Raises:
            ValueError: if ``weight_scheme`` contains neither ``'uniform'`` nor
                ``'length'`` (only checked when data is loaded).
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.padding = kwargs.get('padding', False)
        self.use_interaction_model = kwargs.get('use_interaction_model', False)
        self.key_list = ['person1', 'person2', 'interaction'] if self.mode=='merged' else ['person1', 'person2']
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        motion_repr = {'transl': 3,
                       'poses_6d': 22 * 6,
                       'transl_delta': 3,
                       'global_orient_delta_6d': 6,
                       'joints': 22 * 3,
                       'joints_delta': 22 * 3,
                       }
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # Frames needed for num_primitive consecutive primitives (stride future_length)
        # plus the one extra trailing frame each primitive carries.
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1

        if load_data:
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            if not self.padding:
                dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            elements_to_remove = ['7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385']
            dataset = [data for data in dataset if data['seq_name'] not in elements_to_remove]

            def convert_motion(motion, gender, enforce_zero_beta):
                """Convert one person's numpy motion record into float32 tensors."""
                betas = torch.from_numpy(motion['betas'].astype(np.float32))
                if enforce_zero_beta:
                    betas = torch.zeros_like(betas)
                poses = torch.from_numpy(motion['poses'].astype(np.float32))
                transl = torch.from_numpy(motion['trans'].astype(np.float32))
                global_orient = transforms.axis_angle_to_matrix(poses[:, :3])                       # [T, 3, 3]
                body_pose = transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3))      # [T, 21, 3, 3]
                pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))          # [3]
                joints = torch.from_numpy(motion['joints'].astype(np.float32))                      # [T, 22, 3]
                result = {
                    'gender': gender,
                    'betas': betas,
                    'transl': transl,
                    'global_orient': global_orient,
                    'body_pose': body_pose,
                    'pelvis_delta': pelvis_delta,
                    'joints': joints,
                }
                if self.padding:
                    result['padding_mask'] = motion['padding_mask']
                return result

            for data in dataset:
                if self.padding:
                    # Both persons are padded to person1's length.
                    T = data['motion_p1']['trans'].shape[0]
                    if T < self.seq_length:
                        pad_len = self.seq_length - T
                        for person in ['motion_p1', 'motion_p2']:
                            for key in ['trans', 'poses', 'joints']:
                                # Pad by repeating the last frame.
                                last_frame = data[person][key][-1:]
                                padding = np.repeat(last_frame, pad_len, axis=0)
                                data[person][key] = np.concatenate([
                                    data[person][key],
                                    padding
                                ], axis=0)

                            # True marks padded frames.
                            data[person]['padding_mask'] = np.concatenate([
                                np.zeros(T, dtype=np.bool_),
                                np.ones(pad_len, dtype=np.bool_)
                            ], axis=0)
                    else:
                        for person in ['motion_p1', 'motion_p2']:
                            data[person]['padding_mask'] = np.zeros(T, dtype=np.bool_)

                gender_p1 = self.enforce_gender if self.enforce_gender is not None else data['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else data['motion_p2']['gender']
                data['motion_p1'] = convert_motion(data['motion_p1'], gender_p1, self.enforce_zero_beta)
                data['motion_p2'] = convert_motion(data['motion_p2'], gender_p2, self.enforce_zero_beta)

            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    # After convert_motion the translation lives under 'transl'.
                    data['weight'] = len(data['motion_p1']['transl'])
                else:
                    raise ValueError(f'unsupported weight_scheme: {weight_scheme}')
            print('finish first assigning seq weights')

            # overfit using one sequence
            # if 'overfit' in weight_scheme:
            #     seq_id = int(weight_scheme.split('overfit:')[-1].split('_')[0])
            #     for idx, data in enumerate(dataset):
            #         if idx == seq_id:
            #             data['weight'] = 1.0
            #         else:
            #             data['weight'] = 0.0

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        # Interaction statistics stay None unless loaded or computed below, so
        # get_rel_mean_std_by_device reports a clear error instead of an AttributeError.
        self.rel_mean, self.rel_std = None, None
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'

        mean_std_path = Path(dataset_path, f'{file_name}.pkl')
        if self.padding:
            mean_std_path = Path(dataset_path, f'{file_name}_padding.pkl')
        if self.use_interaction_model:
            # NOTE(review): unlike mean_std_path this name has no '_padding' suffix,
            # although calc_mean_std filters interaction features differently when
            # padding is on — padded and unpadded runs share this cache file.
            mean_std_interaction_path = Path(dataset_path, f'{file_name}_interaction.pkl')
        if mean_std_path.exists() and (not self.use_interaction_model or mean_std_interaction_path.exists()):
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]

            if self.use_interaction_model:
                print(f'loading interaction mean and std from {mean_std_interaction_path}')
                with open(mean_std_interaction_path, 'rb') as f:
                    self.rel_mean, self.rel_std = pickle.load(f)  # [1, 1, D]
        else:
            assert self.split == 'train'
            print('calculating mean and std using train split')
            result = self.calc_mean_std()

            if self.use_interaction_model:
                self.tensor_mean, self.tensor_std, self.rel_mean, self.rel_std = result
            else:
                self.tensor_mean, self.tensor_std = result
                self.rel_mean, self.rel_std = None, None

            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)

            if self.use_interaction_model:
                with open(mean_std_interaction_path, 'wb') as f:
                    pickle.dump((self.rel_mean.detach().cpu(), self.rel_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        self.clip_model = load_and_freeze_clip(clip_version='ViT-B/32', device=self.device)
        self.embedding_path = {}
        embedding_path = {}
        for key_type in self.key_list:
            if text_sep:
                self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict_textsep.pkl')
            else:
                self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict.pkl')
        self.text_embedding_dict = {}
        if text_sep:
            self.text_mask_dict = {}

        for key_type in self.key_list:
            if embedding_path[key_type].exists():
                print(f"Loading text_{key_type} embeddings from {embedding_path[key_type]}!")
                with open(embedding_path[key_type], 'rb') as f:
                    self.text_embedding_dict[key_type] = pickle.load(f)
                if text_sep:
                    with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'rb') as f:
                        self.text_mask_dict[key_type] = pickle.load(f)
            else:
                print('Calculating text embeddings')
                raw_texts = []
                for data in self.dataset:
                    if f'frame_labels_{key_type}' in data:
                        raw_texts.extend([seg['proc_label'] for seg in data['frame_labels_' + key_type]])

                raw_texts = list(set(raw_texts))
                num_texts = len(raw_texts)
                print(f'num of unique texts_{key_type}: ', len(raw_texts))

                # get text embeddings by batch
                text_embeddings = []
                text_mask = []
                batch_start_idx = 0
                while batch_start_idx < num_texts:
                    batch_end_idx = min(batch_start_idx + 256, num_texts)
                    text_embeddings_temp = encode_text(self.clip_model, raw_texts[batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs)
                    if text_sep:
                        text_embeddings.append(text_embeddings_temp[0])
                        text_mask.append(text_embeddings_temp[1])
                    else:
                        text_embeddings.append(text_embeddings_temp)
                    batch_start_idx = batch_end_idx
                # torch.cat rejects an empty list, which happens when no sequence
                # carries labels for this key type; the dicts below are then empty.
                if text_embeddings:
                    text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()

                self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                if text_sep:
                    self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, 512)).astype(np.float32)
                else:
                    self.text_embedding_dict[key_type][''] = np.zeros(512).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
                with open(embedding_path[key_type], 'wb') as f:
                    pickle.dump(self.text_embedding_dict[key_type], f)
                if text_sep:
                    if text_mask:
                        text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                    self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                    self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)  # for empty text have zero embedding, compatible with mdm text masking
                    with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'wb') as f:
                        pickle.dump(self.text_mask_dict[key_type], f)

            # Cached values are stored as numpy arrays; keep them as tensors on the target device.
            for key in self.text_embedding_dict[key_type]:
                self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device=self.device)
                if text_sep:
                    self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device=self.device)

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        """Encode texts with CLIP and add them to the per-key embedding cache.

        New entries are cast to the same dtype/device as the entries prepared in
        ``__init__`` (float32 embeddings, bool masks, on ``self.device``), so
        later ``torch.stack`` calls see uniform tensors.

        Args:
            new_texts: Texts to encode; duplicates are encoded only once.
            key_type: Cache key, one of ``self.key_list``.
            text_sep: If True, ``encode_text`` returns ``(embeddings, masks)``
                and ``self.text_mask_dict`` is updated as well.
            max_segs: Forwarded to ``encode_text``.
        """
        # Preserve first-seen order while dropping duplicates.
        unique_texts = list(dict.fromkeys(new_texts))
        if not unique_texts:
            return
        encoded = encode_text(self.clip_model, unique_texts, text_sep=text_sep, max_segs=max_segs)
        embeddings = encoded[0] if text_sep else encoded
        masks = encoded[1] if text_sep else None
        for idx, text in enumerate(unique_texts):
            self.text_embedding_dict[key_type][text] = embeddings[idx].to(dtype=torch.float32, device=self.device)
            if text_sep:
                self.text_mask_dict[key_type][text] = masks[idx].to(dtype=torch.bool, device=self.device)

    def calc_mean_std(self, batch_size=512):
        """Compute feature normalization statistics over the whole dataset.

        Each sequence is cut into overlapping primitives (stride
        ``future_length``). Each primitive is canonicalized per person and
        turned into features with ``primitive_utility.calc_features``. Mean and
        std are taken over the primitive and time axes of both persons' features.

        With ``use_interaction_model``, interaction features are also collected
        in both directions (``b2a``/``a2b``) from the non-canonicalized
        primitives, restricted to frames after the history. They consist of:
        relative root orientation (6D), relative root translation, and per-joint
        minimum inter-person distance.

        With ``padding``, primitives whose future part contains padded frames
        are excluded.

        Args:
            batch_size: Primitives per canonicalization call; capped at 64 when
                ``future_length == 1``.

        Returns:
            ``(tensor_mean, tensor_std)`` shaped ``[1, 1, D]`` on ``self.device``.
            With the interaction model, additionally ``(rel_mean, rel_std)``
            shaped ``[1, 1, D_rel]``.

        Raises:
            RuntimeError: if the interaction model is enabled but no interaction
                features were collected.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        persons = ['person1', 'person2']
        # Per-person tensor fields concatenated across primitives and sliced into batches.
        tensor_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']
        all_mp_data = []
        all_rel_info = []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['transl'].shape[0]
            # get_primitive includes end_frame, so each primitive needs primitive_length + 1 frames.
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]

            primitive_dict = {}
            for person in persons:
                person_primitives = [data['primitive_dict'][person] for data in primitive_data_list]
                # Plain gender string, the same form get_batch passes to canonicalize/calc_features.
                primitive_dict[person] = {'gender': person_primitives[0]['gender']}
                for key in tensor_keys:
                    primitive_dict[person][key] = torch.cat([p[key] for p in person_primitives], dim=0)
                if self.padding:
                    primitive_dict[person]['primitive_padding_mask'] = torch.cat([p['primitive_padding_mask'] for p in person_primitives], dim=0)
                primitive_dict[person] = tensor_dict_to_device(primitive_dict[person], self.device)

            num_primitives = len(primitive_dict['person1']['transl'])
            for batch_start_idx in range(0, num_primitives, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)
                batch_primitive_dict = {}
                canonicalized_primitive_dict = {}
                for person in persons:
                    batch_primitive_dict[person] = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict[person]['gender'] = primitive_dict[person]['gender']
                    # Canonicalize a copy: the raw batch is reused below for interaction features.
                    _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(batch_primitive_dict[person]), use_predicted_joints=True)

                valid_indices = None
                for person in persons:
                    features = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    # Drop the extra trailing frame.
                    features['transl'] = features['transl'][:, :-1, :]          # [num_primitive, T, 3]
                    features['poses_6d'] = features['poses_6d'][:, :-1, :]      # [num_primitive, T, 66]
                    features['joints'] = features['joints'][:, :-1, :]          # [num_primitive, T, 22 * 3]
                    motion_tensor = self.dict_to_tensor(features).detach().cpu()    # [num_primitive, T, D]
                    if self.padding:
                        # The last mask entry flags padding anywhere in the future part.
                        mask_slice = primitive_dict[person]['primitive_padding_mask'][batch_start_idx:batch_end_idx, -1]  # [B]
                        valid_indices = torch.nonzero(~mask_slice, as_tuple=True)[0].detach().cpu()
                        motion_tensor = motion_tensor[valid_indices]
                    all_mp_data.append(motion_tensor)                           # [num_primitive, T, D]

                if self.use_interaction_model:
                    joints_p1 = batch_primitive_dict['person1']['joints'].reshape(-1, batch_primitive_dict['person1']['joints'].shape[1], 22, 3)  # [B, T+1, 22, 3]
                    joints_p2 = batch_primitive_dict['person2']['joints'].reshape(-1, batch_primitive_dict['person2']['joints'].shape[1], 22, 3)
                    global_orient_p1 = batch_primitive_dict['person1']['global_orient']  # [B, T+1, 3, 3]
                    global_orient_p2 = batch_primitive_dict['person2']['global_orient']
                    transl_p1 = batch_primitive_dict['person1']['transl']  # [B, T+1, 3]
                    transl_p2 = batch_primitive_dict['person2']['transl']

                    rel_global_orient = {
                        'b2a': global_orient_p1.transpose(-1, -2) @ global_orient_p2,
                        'a2b': global_orient_p2.transpose(-1, -2) @ global_orient_p1,
                    }
                    rel_root_transl = {
                        'b2a': torch.matmul(global_orient_p1.transpose(-1, -2), (transl_p2 - transl_p1).unsqueeze(-1)).squeeze(-1),
                        'a2b': torch.matmul(global_orient_p2.transpose(-1, -2), (transl_p1 - transl_p2).unsqueeze(-1)).squeeze(-1),
                    }

                    dists = torch.norm(joints_p1.unsqueeze(3) - joints_p2.unsqueeze(2), dim=-1)  # [B, T+1, 22, 22]
                    rel_mindis = {
                        'b2a': dists.min(dim=-1).values,  # [B, T+1, 22]
                        'a2b': dists.min(dim=-2).values,
                    }

                    for key in ['b2a', 'a2b']:
                        rot_6d = transforms.matrix_to_rotation_6d(rel_global_orient[key])  # [B, T+1, 6]
                        rel_info = torch.cat([rot_6d, rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B, T+1, 6+3+22]
                        rel_info = rel_info[:, self.cfg.history_length:-1, :].detach().cpu()  # [B, T-(1+history_length), D]
                        if self.padding:
                            # Uses person2's mask; __init__ pads both persons to the same length.
                            rel_info = rel_info[valid_indices]
                        all_rel_info.append(rel_info)

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        if not self.use_interaction_model:
            return tensor_mean.to(self.device), tensor_std.to(self.device)

        if len(all_rel_info) == 0:
            raise RuntimeError('no interaction features were collected; cannot compute rel_mean/rel_std')
        all_rel_info = torch.cat(all_rel_info, dim=0)  # [2N, (T-(1+history_length)), D]
        rel_mean = all_rel_info.mean(dim=[0, 1], keepdim=True)
        rel_std = all_rel_info.std(dim=[0, 1], keepdim=True)
        return tensor_mean.to(self.device), tensor_std.to(self.device), rel_mean.to(self.device), rel_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """Slice one motion primitive for both persons and pick its texts.

        Frames ``start_frame`` through ``end_frame`` (inclusive) are taken from
        ``seq_data['motion_p1']`` and ``seq_data['motion_p2']``, each with a
        leading batch dimension of 1. The canonical transform starts as identity
        rotation and zero translation.

        For each key in ``self.key_list``, one label is drawn at random from the
        segments of ``frame_labels_<key>`` that overlap the future window,
        widened by ``text_tolerance`` seconds. The label is ``''`` when nothing
        overlaps, the labels are missing, or ``skip_text`` is set.

        Returns:
            dict with ``'text_<key>'`` entries followed by ``'primitive_dict'``,
            which maps ``'person1'``/``'person2'`` to their sliced tensors (plus
            ``'primitive_padding_mask'`` of shape ``(1, history_length + 1)``
            when padding is enabled).
        """
        frames = slice(start_frame, end_frame + 1)
        primitive_dict = {}
        for person_no, person in enumerate(['person1', 'person2'], start=1):
            motion = seq_data[f'motion_p{person_no}']
            sliced = {
                'gender': motion['gender'],
                'betas': motion['betas'].expand(1, self.primitive_length + 1, 10),
                'transl': motion['transl'][frames].unsqueeze(0),
                'global_orient': motion['global_orient'][frames].unsqueeze(0),
                'body_pose': motion['body_pose'][frames].unsqueeze(0),
                'pelvis_delta': motion['pelvis_delta'].unsqueeze(0),
                'joints': motion['joints'][frames].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }
            if self.padding:
                frame_mask = motion['padding_mask'][frames]  # [T+1]
                # History frames keep their per-frame flags; the future part collapses to one flag.
                history_flags = torch.tensor(frame_mask[:self.history_length], dtype=torch.bool)
                future_flag = torch.tensor(frame_mask[self.history_length:-1].any(), dtype=torch.bool)
                sliced['primitive_padding_mask'] = torch.cat([history_flags, future_flag.unsqueeze(0)], dim=0).unsqueeze(0)
            primitive_dict[person] = sliced

        if not skip_text:
            first_future = start_frame + self.history_length
            window_lo = first_future / self.target_fps - self.text_tolerance
            window_hi = (first_future + self.future_length - 1) / self.target_fps + self.text_tolerance

        output = {}
        for key_type in self.key_list:
            label_key = f'frame_labels_{key_type}'
            candidates = []
            if not skip_text and label_key in seq_data:
                candidates = [
                    seg['proc_label'] for seg in seq_data[label_key]
                    if have_overlap([seg['start_t'], seg['end_t']], [window_lo, window_hi])
                ]
            output['text_' + key_type] = random.choice(candidates) if candidates else ''
        output['primitive_dict'] = primitive_dict
        return output

    def get_rel_mean_std_by_device(self, device):
        """Return ``(rel_mean, rel_std)`` moved to ``device``, cached per device.

        Args:
            device: Target device; used as the cache key.

        Returns:
            Tuple of the interaction-feature mean and std on ``device``.

        Raises:
            AssertionError: if the interaction statistics are missing or None.
                The error is raised explicitly, so it also fires under ``python -O``.
        """
        cache = getattr(self, 'rel_mean_device_dict', None)
        if cache is None:
            cache = self.rel_mean_device_dict = {}

        if device not in cache:
            # getattr: the attributes may never have been set when the
            # interaction model is disabled.
            rel_mean = getattr(self, 'rel_mean', None)
            rel_std = getattr(self, 'rel_std', None)
            if rel_mean is None or rel_std is None:
                raise AssertionError("rel_mean/std must be computed before normalization.")
            cache[device] = (
                rel_mean.to(device=device),
                rel_std.to(device=device)
            )
        return cache[device]

    def normalize_rel_info(self, rel_info: torch.Tensor) -> torch.Tensor:
        """
        Standardize interaction feature tensor using rel_mean / rel_std.

        rel_info: Tensor whose last dimension is the interaction feature
            dimension D, e.g. shape [B, D] or [B, T, D].
        Returns a tensor with the same shape as ``rel_info``.
        """
        rel_mean, rel_std = self.get_rel_mean_std_by_device(rel_info.device)
        # The stored statistics keep singleton leading dims (e.g. [1, 1, D]).
        # Flatten them to [D] so they broadcast over any [..., D] input without
        # adding dimensions (a [B, D] input would otherwise become [1, B, D]).
        return (rel_info - rel_mean.reshape(-1)) / rel_std.reshape(-1)

    def get_batch(self, batch_size=8):
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        add_key_list = ['texts', 'gender']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'text_embedding', 'transf_rotmat', 'transf_transl']
        if self.padding:
            cat_key_list.append('primitive_padding_mask')
        if self.text_sep:
            cat_key_list.append('text_mask')
        
        for seq_idx in batch_idx:
            seq_data = self.dataset[seq_idx]
            num_frames = len(seq_data['motion_p1']['transl'])
            if 'text' in self.weight_scheme:
                start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
            else:
                start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
            primitive_data_list = []
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        batch = None
        for gender in ['female', 'male']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = {} if self.mode == 'merged' else []
            
            gender_seq_texts = {key_type: None for key_type in self.key_list}
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                    if self.padding:
                        primitive_dict[person]['primitive_padding_mask'] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person]['primitive_padding_mask'] for mp_seq in gender_seq_list], dim=0)
                primitive_texts = {}
                for key_type in self.key_list:
                    primitive_texts[key_type] = [mp_seq[primitive_idx]['text_'+key_type] for mp_seq in gender_seq_list]
                    gender_seq_texts[key_type] = primitive_texts[key_type] if gender_seq_texts[key_type] is None else gender_seq_texts[key_type] + primitive_texts[key_type]
                
                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                        if self.padding:
                            gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'], 
                                                                                           primitive_dict[person]['primitive_padding_mask']], dim=0)

            canonicalized_primitive_dict = {}
            transf_rotmat, transf_transl = {}, {}
            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)            
            
            if self.mode == 'merged':
                # rel_pose
                rel_rotmat, rel_transl, rel_pose = {}, {}, {}
                rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
                rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
                for rel in ['b2a', 'a2b']:
                    rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                    rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]
            
                if self.use_interaction_model:
                    # reltive transition, relative distance
                    rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                    rel_global_orient['b2a'] = gender_seq_dict['person1']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person2']['global_orient']
                    rel_global_orient['a2b'] = gender_seq_dict['person2']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person1']['global_orient']
                    rel_root_transl['b2a'] = torch.matmul(gender_seq_dict['person1']['global_orient'].transpose(-1, -2), (gender_seq_dict['person2']['transl']-gender_seq_dict['person1']['transl']).unsqueeze(-1)).squeeze(-1)
                    rel_root_transl['a2b'] = torch.matmul(gender_seq_dict['person2']['global_orient'].transpose(-1, -2), (gender_seq_dict['person1']['transl']-gender_seq_dict['person2']['transl']).unsqueeze(-1)).squeeze(-1)
                    
                    dists = torch.norm(gender_seq_dict['person1']['joints'].unsqueeze(3)-gender_seq_dict['person2']['joints'].unsqueeze(2), dim=-1)
                    rel_mindis['b2a'], _ = dists.min(dim=-1)
                    rel_mindis['a2b'], _ = dists.min(dim=-2)
                    
                    rel_info = {}
                    for key in ['b2a', 'a2b']:
                        rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B*num_mp, T, 6+3+22]
                
            # calc features
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]
            
            
            if self.mode == 'merged':
                for person in ['person1', 'person2']:
                    gender_batch[person] = []
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                        text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None
               
                        gender_batch[person].append({
                                'texts': primitive_texts,
                                'text_embedding': text_embedding,
                                'text_mask': text_mask, 
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            })
                        if self.padding:
                            gender_batch[person][-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                gender_batch['interaction'] = []
                for primitive_idx in range(self.num_primitive):        
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                    else:
                        text_mask = None
                    gender_batch['interaction'].append({
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'text_mask': text_mask, 
                            'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                            'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                        })
                    if self.use_interaction_model:
                        gender_batch['interaction'][-1].update({
                            'rel_info_b2a': self.normalize_rel_info(rel_info['b2a'][start_idx:end_idx, self.cfg.history_length:-1]),
                            'rel_info_a2b': self.normalize_rel_info(rel_info['a2b'][start_idx:end_idx, self.cfg.history_length:-1]),
                        })
                    
            elif self.mode == 'sep':
                for person in ['person1', 'person2']:
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                            # new_text_embeddings = encode_text(self.clip_model, unseen_texts)
                            # for idx, text in enumerate(unseen_texts):
                            #     self.text_embedding_dict[person][text] = new_text_embeddings[idx]
                        text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None
                        gender_batch.append(
                            {
                                'texts': primitive_texts,
                                'text_embedding': text_embedding,
                                'text_mask': text_mask, 
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            }
                        )
                        if self.padding:
                            gender_batch[-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]
                
                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key], back_group[key] = [], []
                    for d in gender_batch[:self.num_primitive]:
                        front_group[key] += d[key]
                    for d in gender_batch[self.num_primitive:]:
                        back_group[key] += d[key]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                    back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]  
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]  

                selected_batch = []
                for i in range(self.num_primitive):
                    selected_dict = {'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,}
                    for key in front_group.keys():    
                        if key in add_key_list:
                            selected_front = [front_group[key][i] for i in front_indices + i * gender_batch_size] 
                            selected_back = [back_group[key][i] for i in back_indices + i * gender_batch_size]
                            selected_dict[key] = selected_front + selected_back
                        elif key in cat_key_list:
                            selected_front = front_group[key][front_indices + i * gender_batch_size] 
                            selected_back = back_group[key][back_indices + i * gender_batch_size]
                            selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)  
                    selected_batch.append(selected_dict)
                gender_batch = selected_batch
                            
            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    if self.mode == 'merged':
                        for key_type in self.key_list:
                            if key_type != 'interaction':
                                for key in add_key_list:
                                    batch[key_type][primitive_idx][key] = batch[key_type][primitive_idx][key] + gender_batch[key_type][primitive_idx][key]
                                for key in cat_key_list:
                                    batch[key_type][primitive_idx][key] = torch.cat([batch[key_type][primitive_idx][key], gender_batch[key_type][primitive_idx][key]], dim=0)
                            else:
                                for key in ['texts']:
                                    batch[key_type][primitive_idx][key] = batch[key_type][primitive_idx][key] + gender_batch[key_type][primitive_idx][key]
                                for key in ['text_embedding', 'text_mask', 'rel_pose_b2a', 'rel_pose_a2b', 'rel_info_b2a', 'rel_info_a2b']:
                                    batch[key_type][primitive_idx][key] = torch.cat([batch[key_type][primitive_idx][key], gender_batch[key_type][primitive_idx][key]], dim=0)
                    else:
                        for key in add_key_list:
                            batch[primitive_idx][key] = batch[primitive_idx][key] + gender_batch[primitive_idx][key]
                        for key in cat_key_list:
                            batch[primitive_idx][key] = torch.cat([batch[primitive_idx][key], gender_batch[primitive_idx][key]], dim=0)
        return batch
    
    def get_item(self, idx):
        """Sample one random window from sequence ``idx`` and build per-primitive model inputs.

        A window of ``self.seq_length`` frames is drawn from the two-person sequence.
        The start is uniform, or weighted by ``seq_data['frame_weights']`` when ``'text'``
        appears in ``self.weight_scheme``. The window is cut into primitives of
        ``self.primitive_length`` frames that start every ``self.future_length`` frames.
        Each person's primitives are then canonicalized, turned into features, normalized,
        and paired with cached text embeddings.

        Args:
            idx: Index into ``self.dataset``.

        Returns:
            * ``self.mode == 'merged'``: a dict with three keys.
              - ``'person1'`` and ``'person2'``: each is a list of ``self.num_primitive``
                dicts holding normalized motion, history mask/motion, canonicalization
                transforms and text embeddings.
              - ``'interaction'``: a list of per-primitive dicts with the interaction text
                embedding and the relative pose of each person in the other's frame. When
                ``self.use_interaction_model`` is set, each dict also has ``rel_info_*``
                features.
            * otherwise (``'sep'``): a list of ``self.num_primitive`` per-primitive dicts
              for ONE person, picked with probability 0.5 each between person1 and person2.
        """
        seq_data = self.dataset[idx]
        num_frames = len(seq_data['motion_p1']['transl'])
        gender = {}
        gender['person1'] = seq_data['motion_p1']['gender']
        gender['person2'] = seq_data['motion_p2']['gender']

        # Window start lies in [0, num_frames - seq_length], right end inclusive.
        if 'text' in self.weight_scheme:
            start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
        else:
            start_frame = random.randint(0, num_frames - self.seq_length)
        primitive_data_list = [
            self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length)
        ]

        # Per-person tensors copied from every primitive and concatenated along dim 0.
        seq_keys = ('betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat',
                    'transf_transl', 'pelvis_delta', 'joints')

        gender_seq_texts = {key_type: [] for key_type in self.key_list}
        gender_seq_dict = None
        for primitive_idx in range(self.num_primitive):
            primitive_dict = {}
            for person in ['person1', 'person2']:
                primitive_dict[person] = {}
                primitive_dict[person]['gender'] = gender[person]
                for key in seq_keys:
                    primitive_dict[person][key] = primitive_data_list[primitive_idx]['primitive_dict'][person][key]
                if self.padding:
                    primitive_dict[person]['primitive_padding_mask'] = primitive_data_list[primitive_idx]['primitive_dict'][person]['primitive_padding_mask']
            for key_type in self.key_list:
                gender_seq_texts[key_type].append(primitive_data_list[primitive_idx]['text_' + key_type])

            if gender_seq_dict is None:
                gender_seq_dict = primitive_dict
            else:
                for person in ['person1', 'person2']:
                    for key in seq_keys:
                        gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                    if self.padding:
                        gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'],
                                                                                       primitive_dict[person]['primitive_padding_mask']], dim=0)

        # Canonicalize a deep copy so gender_seq_dict keeps the original (world-frame) data,
        # which the relative-pose computation below reads.
        canonicalized_primitive_dict = {}
        if self.mode == 'merged':
            transf_rotmat, transf_transl = {}, {}
        for person in ['person1', 'person2']:
            gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            if self.mode == 'merged':
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)
            else:
                _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)

        if self.mode == 'merged':
            # Relative canonical transform of each person expressed in the other's frame,
            # encoded as 6D rotation + translation.
            rel_rotmat, rel_transl, rel_pose = {}, {}, {}
            rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
            for rel in ['b2a', 'a2b']:
                rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [num_mp, 6+3]

            if self.use_interaction_model:
                # Per-frame relative root orientation/translation and per-joint minimum
                # distance to the other person's joints.
                rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                rel_global_orient['b2a'] = gender_seq_dict['person1']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person2']['global_orient']
                rel_global_orient['a2b'] = gender_seq_dict['person2']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person1']['global_orient']
                rel_root_transl['b2a'] = torch.matmul(gender_seq_dict['person1']['global_orient'].transpose(-1, -2), (gender_seq_dict['person2']['transl']-gender_seq_dict['person1']['transl']).unsqueeze(-1)).squeeze(-1)
                rel_root_transl['a2b'] = torch.matmul(gender_seq_dict['person2']['global_orient'].transpose(-1, -2), (gender_seq_dict['person1']['transl']-gender_seq_dict['person2']['transl']).unsqueeze(-1)).squeeze(-1)

                dists = torch.norm(gender_seq_dict['person1']['joints'].unsqueeze(3)-gender_seq_dict['person2']['joints'].unsqueeze(2), dim=-1)
                rel_mindis['b2a'], _ = dists.min(dim=-1)
                rel_mindis['a2b'], _ = dists.min(dim=-2)

                rel_info = {}
                for key in ['b2a', 'a2b']:
                    rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [num_mp, T, 6+3+22]

        feature_dict = {}
        data_batch = {} if self.mode == 'merged' else []
        for person in ['person1', 'person2']:
            feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
            # Drop the last frame of the absolute features; presumably aligns them with the
            # delta features, which have one frame fewer — TODO confirm against calc_features.
            feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [num_mp, T, 3]
            feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [num_mp, T, 66]
            feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [num_mp, T, 22 * 3]

        if self.mode == 'merged':
            for person in ['person1', 'person2']:
                data_batch[person] = []
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [num_mp, D, 1, T]
                # Only the first history_length frames are exposed as conditioning history.
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx
                    end_idx = primitive_idx + 1
                    primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [1, max_segs]
                    else:
                        text_mask = None

                    data_batch[person].append({
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'text_mask': text_mask,
                            'gender': [gender_seq_dict[person]['gender']],
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [1, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                            'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        })
                    if self.padding:
                        data_batch[person][-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
            data_batch['interaction'] = []
            for primitive_idx in range(self.num_primitive):
                start_idx = primitive_idx
                end_idx = primitive_idx + 1
                primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                if len(unseen_texts) > 0:
                    self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)
                if self.text_sep:
                    text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [1, max_segs]
                else:
                    text_mask = None
                data_batch['interaction'].append({
                        'texts': primitive_texts,
                        'text_embedding': text_embedding,
                        'text_mask': text_mask,
                        'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                        'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                    })
                if self.use_interaction_model:
                    # NOTE(review): the batched path passes rel_info through
                    # self.normalize_rel_info and keeps the time axis. Here it is left
                    # unnormalized and flattened to [1, -1]. Confirm this difference is intended.
                    data_batch['interaction'][-1].update({
                        'rel_info_b2a': rel_info['b2a'][start_idx:end_idx, self.cfg.history_length:-1].reshape(1, -1),
                        'rel_info_a2b': rel_info['a2b'][start_idx:end_idx, self.cfg.history_length:-1].reshape(1, -1),
                    })
            return data_batch
        else:
            for person in ['person1', 'person2']:
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [num_mp, D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx
                    end_idx = primitive_idx + 1
                    primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                    if len(unseen_texts) > 0:
                        # Go through the shared cache updater, as every other path does.
                        # Writing raw encode_text() outputs here ignored text_sep/max_segs,
                        # never filled self.text_mask_dict, and required self.clip_model.
                        self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [1, max_segs]
                    else:
                        text_mask = None
                    data_batch.append(
                        {
                            'texts': primitive_texts,
                            'text_embedding': text_embedding,
                            'text_mask': text_mask,
                            'gender': [gender_seq_dict[person]['gender']],
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [1, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        }
                    )
                    if self.padding:
                        data_batch[-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
            # data_batch holds person1's primitives followed by person2's; return one person at random.
            if random.random() < 0.5:
                return data_batch[:self.num_primitive]
            else:
                return data_batch[self.num_primitive:]

# dataset = InterHumanDatasetV4(dataset_path='./data/InterHuman/seq_data_single_interaction_fps30',
#                               cfg_path='./config_files/config_hydra/motion_primitive/interhuman_h2_f8_r4.yaml',
#                               split="val",
#                               enforce_gender='None',
#                               enforce_zero_beta=0,
#                               device='cuda:0',
#                               mode='merged',
#                               text_encoder='clip',
#                               text_sep=True,
#                               use_interaction_model=False,
#                               padding=True,)
# # dataset.calc_mean_std()

# batch_test = dataset.get_batch(batch_size=8)
# sample = dataset.get_item(0)

class InterHumanDatasetEval(data.Dataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_single_interaction_d262_fps30_mirror_exchangeyz',
                 cfg_path='./config_files/config_hydra/motion_primitive/interhuman_h2_f8_r4.yaml',
                 prob_static=0.0,
                 weight_scheme='uniform',
                 split="test",
                 device='cuda',
                 load_data=True,
                 enforce_gender='male',
                 enforce_zero_beta = True, 
                 body_type='smplh',
                 mode = 'merged',
                 opt = 'generate', # 'generate' or 'eval'
                 text_sep = True,
                 max_segs = 20,
                 min_length = 15,
                 max_length=300,
                 motion_repr = {
                    'transl': 3,
                    'poses_6d': 22 * 6,
                    'transl_delta': 3,
                    'global_orient_delta_6d': 6,
                    'joints': 22 * 3,
                    'joints_delta': 22 * 3,
                },
                 padding=False,
                 **kwargs):
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.mode = mode
        self.opt = opt
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.min_length = min_length
        self.max_length = max_length
        
        self.cut_length = kwargs.get('cut_length', 0)
        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', True)
        self.use_indi_text = kwargs.get('use_indi_text', True)
        
        self.key_list = ['person1', 'person2', 'interaction'] if self.mode=='merged' else ['person1', 'person2']
        self.text_key_list = ['person1', 'person2', 'interaction'] if (self.use_indi_text or self.mode=='sep') else ['interaction']
        
        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length

        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        if load_data:
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            
            elements_to_remove = ['7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385']
            if self.cut_length > 0:
                dataset = [data for data in dataset if (data['seq_name'] not in elements_to_remove) and (data['motion_p1']['trans'].shape[0] - 1 <= self.cut_length)]
            else:
                dataset = [data for data in dataset if data['seq_name'] not in elements_to_remove]
            filtered_dataset = []
            for data in dataset:
                T = data['motion_p1']['trans'].shape[0] - 1
                if T < (self.min_length-1):
                    continue
                else:
                    if T < self.max_length:
                        pad_len = self.max_length - T
                        for person in ['motion_p1', 'motion_p2']:
                            for key in data[person].keys():
                                if key in ['trans', 'poses', 'joints', 'pose_body', 'global_orient']:
                                    last_frame = data[person][key][-1:]
                                    padding_data = np.repeat(last_frame, pad_len, axis=0)
                                    data[person][key] = np.concatenate([
                                        data[person][key],
                                        padding_data
                                    ], axis=0)

                            # padding_mask
                            data[person]['padding_mask'] = np.concatenate([
                                np.zeros(T, dtype=np.bool_),
                                np.ones(pad_len, dtype=np.bool_)
                            ], axis=0)
                    else:
                        for person in ['motion_p1', 'motion_p2']:
                            data[person]['padding_mask'] = np.zeros(T, dtype=np.bool_)
                def convert_motion(motion, gender, enforce_zero_beta):
                    betas = torch.from_numpy(motion['betas'].astype(np.float32))
                    if enforce_zero_beta:
                        betas = torch.zeros_like(betas)
                    if self.primitive_utility.feature_dim == 276:
                        poses = torch.from_numpy(motion['poses'].astype(np.float32))
                        global_orient = transforms.axis_angle_to_matrix(poses[:, :3])                       # [T, 3, 3]
                        body_pose = transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3))      # [T, 21, 3, 3]
                    elif self.primitive_utility.feature_dim == 262:
                        global_orient = transforms.axis_angle_to_matrix(torch.from_numpy(motion['global_orient'].astype(np.float32)))   # [T, 3, 3]
                        body_pose = torch.from_numpy(motion['pose_body'].astype(np.float32)).reshape(-1, 21, 6)                         # [T, 21, 6]
                    transl = torch.from_numpy(motion['trans'].astype(np.float32))
                    pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))              # [3]
                    joints = torch.from_numpy(motion['joints'].astype(np.float32))                          # [T, 22, 3]
                    result = {
                        'gender': gender,
                        'betas': betas.unsqueeze(0).expand(1, transl.shape[0], 10),
                        'transl': transl.unsqueeze(0),
                        'global_orient': global_orient.unsqueeze(0),
                        'body_pose': body_pose.unsqueeze(0),
                        'pelvis_delta': pelvis_delta.unsqueeze(0),
                        'joints': joints.unsqueeze(0),
                        'transf_rotmat': torch.eye(3).unsqueeze(0),
                        'transf_transl': torch.zeros(1, 1, 3),
                        'padding_mask': motion['padding_mask']
                    }
                    return result
                gender_p1 = self.enforce_gender if self.enforce_gender is not None else data['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else data['motion_p2']['gender']
                data['motion_p1'] = convert_motion(data['motion_p1'], gender_p1, self.enforce_zero_beta)
                data['motion_p2'] = convert_motion(data['motion_p2'], gender_p2, self.enforce_zero_beta)
                filtered_dataset.append(data)  
            
            print('num of sequences: ', len(dataset))

            self.dataset = filtered_dataset
        
        self.tensor_mean_device_dict = {}
        suffix = '_padding' if padding else ''
        mean_std_path = Path(dataset_path, f'mean_std_h{self.history_length}_f{self.future_length}{suffix}.pkl')
        try:
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        except FileNotFoundError:
            print('Error: mean and std not found!')
            
        # load clip model, get train text embeddings
        if self.load_text_embedding:
            self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
            self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
            suffix = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
            self.embedding_path = {}
            embedding_path = {}
            for key_type in self.text_key_list:
                if text_sep:
                    self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict_textsep{suffix}.pkl')
                else:
                    self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{suffix}.pkl')
            self.text_embedding_dict = {}
            if text_sep:
                self.text_mask_dict = {}
            
            for key_type in self.text_key_list:
                if embedding_path[key_type].exists():
                    print(f"Loading text_{key_type} embeddings from {embedding_path[key_type]}!")
                    with open(embedding_path[key_type], 'rb') as f:
                        self.text_embedding_dict[key_type] = pickle.load(f)
                    if text_sep:
                        with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'rb') as f:
                            self.text_mask_dict[key_type] = pickle.load(f)
                else:
                    print('Calculating text embeddings')
                    raw_texts = []
                    for data in self.dataset:
                        if f'frame_labels_{key_type}' in data:
                            raw_texts.extend([seg['proc_label'] for seg in data['frame_labels_' + key_type]])

                    raw_texts = list(set(raw_texts))
                    num_texts = len(raw_texts)
                    print(f'num of unique texts_{key_type}: ', len(raw_texts))
                        
                    # get text embeddings by batch
                    text_embeddings = []
                    text_mask = []
                    batch_start_idx = 0
                    while batch_start_idx < num_texts:
                        batch_end_idx = min(batch_start_idx + 256, num_texts)
                        text_embeddings_temp = self.encode_text(raw_texts[batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs)
                        if text_sep:
                            text_embeddings.append(text_embeddings_temp[0])
                            text_mask.append(text_embeddings_temp[1])
                        else:
                            text_embeddings.append(text_embeddings_temp)
                        batch_start_idx = batch_end_idx
                    text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()
                
                    self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                    if text_sep:
                        self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
                    else:
                        self.text_embedding_dict[key_type][''] = np.zeros(self.dim_embed_text).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
                    with open(embedding_path[key_type], 'wb') as f:
                        pickle.dump(self.text_embedding_dict[key_type], f)
                    if text_sep:
                        text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                        self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                        self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)  # for empty text have zero embedding, compatible with mdm text masking
                        with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'wb') as f:
                            pickle.dump(self.text_mask_dict[key_type], f)
                
                for key in self.text_embedding_dict[key_type]:
                    self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device='cpu')
                    if text_sep:
                        self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device='cpu')

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load CLIP ``clip_version`` onto ``device`` as ``self.clip_model`` and freeze it.

        The model is switched to eval mode and every parameter gets
        ``requires_grad = False``.
        """
        # jit=False returns a regular nn.Module (required so the weights can be frozen).
        self.clip_model, _ = clip.load(clip_version, device=device, jit=False)
        # ``clip.load`` already returns fp16 weights on non-CPU devices and
        # deliberately casts to fp32 when ``str(device) == "cpu"``. Converting to
        # fp16 unconditionally would undo that cast and run the text encoder in
        # half precision on CPU, so only convert where it is a no-op.
        if str(device) != 'cpu':
            clip.model.convert_weights(self.clip_model)

        # Freeze CLIP weights
        self.clip_model.eval()
        for p in self.clip_model.parameters():
            p.requires_grad = False
    
    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs = 20, sep_mode=0):
        """Encode a list of strings with the frozen CLIP text encoder.

        Args:
            raw_text: list of strings to encode.
            force_empty_zero: zero the embedding of empty text. Without
                ``text_sep`` this applies to whole strings equal to ``''``; with
                ``text_sep`` it applies to every empty segment.
            text_sep: when True, each string is stripped, a trailing ``.`` is
                removed, and the string is split into at most ``max_segs``
                segments that are encoded separately.
            max_segs: number of segment slots per string in ``text_sep`` mode.
            sep_mode: segment delimiter in ``text_sep`` mode. ``0`` splits on
                ``,``/``.``; ``1`` additionally splits on the words "and"/"while".

        Returns:
            Without ``text_sep``: float tensor ``[B, D]`` with
            ``D == self.dim_embed_text``.
            With ``text_sep``: ``(embedding [B, max_segs, D], mask [B, max_segs])``,
            where ``mask`` is True for empty (unused) segments.

        Raises:
            ValueError: ``text_sep`` is True and ``sep_mode`` is not 0 or 1.
        """
        import pandas as pd

        split_patterns = {0: r'[,.]', 1: r'\band\b|\bwhile\b|,|\.'}
        # Validate up front; an unknown mode previously surfaced as an
        # UnboundLocalError on ``split_df`` after CLIP had been touched.
        if text_sep and sep_mode not in split_patterns:
            raise ValueError(f'unsupported sep_mode {sep_mode!r}; expected one of {sorted(split_patterns)}')

        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, 512]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding

        # dtype=object keeps the ``.str`` accessor usable for an empty list.
        raw_series = pd.Series(raw_text, dtype=object).str.strip().str.rstrip('.')
        split_df = raw_series.str.split(split_patterns[sep_mode], n=max_segs - 1, expand=True)
        # Strip every segment column-wise (DataFrame.applymap is deprecated in pandas >= 2.1).
        split_df = split_df.fillna('').astype(str).apply(lambda col: col.str.strip())
        # Strings with fewer segments are right-padded with '' up to max_segs columns.
        split_df = split_df.reindex(columns=range(max_segs), fill_value='')

        segs_matrix = split_df.values
        segs_flat = segs_matrix.reshape(-1).tolist()

        text_mask = torch.tensor((segs_matrix == '').astype(bool), dtype=torch.bool, device=device)  # [B, max_segs]

        # Inference only: mirror the non-sep path and skip autograd bookkeeping.
        with torch.no_grad():
            tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
            text_embedding = self.clip_model.encode_text(tokenized).float()  # [B*max_segs, 512]
            text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, 512]
            if force_empty_zero:
                text_embedding[text_mask] = 0

        return text_embedding, text_mask

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        """Encode ``new_texts`` with CLIP and add them to the ``key_type`` lookup tables.

        Args:
            new_texts: list of strings to encode.
            key_type: table to update, e.g. ``'person1'`` or ``'interaction'``.
            text_sep: encode per segment and also store the segment masks in
                ``self.text_mask_dict[key_type]``.
            max_segs: segment slots per text when ``text_sep`` is True.

        New entries are stored detached on the CPU, as float32 embeddings and
        bool masks. This is the same form the constructor gives the cached
        entries, so ``torch.stack`` over old and new entries never mixes
        devices or dtypes.
        """
        encoded = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs)
        if text_sep:
            embeddings, masks = encoded
            masks = masks.detach().to(dtype=torch.bool, device='cpu')
        else:
            embeddings, masks = encoded, None
        embeddings = embeddings.detach().to(dtype=torch.float32, device='cpu')

        for idx, text in enumerate(new_texts):
            self.text_embedding_dict[key_type][text] = embeddings[idx]
            if masks is not None:
                self.text_mask_dict[key_type][text] = masks[idx]
    
    def get_mean_std_by_device(self, device):
        """Return ``(tensor_mean, tensor_std)`` moved to ``device``.

        The moved pair is memoised in ``self.tensor_mean_device_dict``, so the
        transfer happens at most once per device.
        """
        cache = self.tensor_mean_device_dict
        if device not in cache:
            cache[device] = tuple(stat.to(device=device) for stat in (self.tensor_mean, self.tensor_std))
        return cache[device]

    def normalize(self, tensor):
        """Standardise ``tensor`` with the dataset mean/std on its own device.

        Feature dimensions whose std is exactly zero are divided by 1 instead,
        which avoids a division by zero. Those dimensions come out as
        ``tensor - mean``.
        """
        mean, std = self.get_mean_std_by_device(tensor.device)
        divisor = torch.where(std == 0, torch.ones_like(std), std)
        centred = tensor - mean
        return centred / divisor  # [B, T, D]
    
    def denormalize(self, tensor):
        """Map a normalised tensor back to feature space as ``tensor * std + mean``.

        Unlike ``normalize``, the raw std is used here. Dimensions whose std is
        zero therefore come back as exactly ``mean``.
        """
        mean, std = self.get_mean_std_by_device(tensor.device)
        scaled = tensor * std
        return scaled + mean  # [B, T, D]
    
    def __len__(self):
        """Return the number of sequences held in ``self.dataset``."""
        sequences = self.dataset
        return len(sequences)
        
    def __getitem__(self, idx):
        """Return one randomly augmented sample for sequence ``idx``.

        The stored sequence is deep-copied and the two people are swapped with
        probability 0.5. Sequences with more than ``self.max_length + 1`` frames
        are randomly cropped. In eval mode (``self.opt == 'eval'``) the squeezed
        motion dicts are returned directly. Otherwise each person is
        canonicalised, converted to normalised features and paired with a
        randomly chosen text label; the cached CLIP embedding is attached when
        ``self.load_text_embedding`` is set.

        Returns:
            ``(seq_name, text, sample, valid_length)``, where
            ``valid_length = self.max_length - <number of padded frames>``.
            In ``'merged'`` mode ``sample`` has ``'person1'``, ``'person2'`` and
            ``'interaction'`` entries. Otherwise it is the dict of one randomly
            chosen person.
        """
        squeeze_keys = ('betas', 'transl', 'global_orient', 'body_pose', 'pelvis_delta',
                        'joints', 'transf_rotmat', 'transf_transl')
        person_idx = {'person1': 'motion_p1', 'person2': 'motion_p2'}

        seq_data = copy.deepcopy(self.dataset[idx])

        def sample_label(label_key):
            # Uniformly pick one annotated label stored under ``label_key``.
            return random.choice([itext['proc_label'] for itext in seq_data[label_key]])

        # Exchange person1 and person2. Individual labels travel with their
        # motion whenever both are present, so a motion is never paired with
        # the other person's description (this matters in 'sep' mode too).
        exchange = random.random() < 0.5
        if exchange:
            seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
            if 'frame_labels_person1' in seq_data and 'frame_labels_person2' in seq_data:
                seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = \
                    seq_data['frame_labels_person2'], seq_data['frame_labels_person1']

        def embedding_table(person):
            # Per-person embedding tables are built from each person's ORIGINAL
            # labels; after an exchange the labels in slot ``person`` belong to
            # the other person's table.
            if not exchange:
                return person
            return 'person2' if person == 'person1' else 'person1'

        def lookup_text_embedding(key_type, texts):
            # Stack cached embeddings (and masks when ``self.text_sep``) for ``texts``.
            text_embedding = torch.stack([self.text_embedding_dict[key_type][t] for t in texts], dim=0)  # [1, 512]
            text_mask = (torch.stack([self.text_mask_dict[key_type][t] for t in texts], dim=0)
                         if self.text_sep else None)
            return text_embedding, text_mask

        # Truncate sequence if too long: keep max_length + 1 frames of motion
        # and max_length entries of padding_mask.
        length = seq_data['motion_p1']['betas'].shape[1]
        if length > self.max_length + 1:
            cut_idx = random.randrange(length - self.max_length)
            for motion_key in ('motion_p1', 'motion_p2'):
                motion = seq_data[motion_key]
                for key in ('betas', 'transl', 'global_orient', 'body_pose', 'joints'):
                    motion[key] = motion[key][:, cut_idx:cut_idx + self.max_length + 1]
                motion['padding_mask'] = motion['padding_mask'][cut_idx:cut_idx + self.max_length]

        # Fast path for eval mode
        if self.opt == 'eval':
            if self.mode == 'merged':
                seq_data['person1'] = seq_data.pop('motion_p1')
                seq_data['person2'] = seq_data.pop('motion_p2')
                for person in ('person1', 'person2'):
                    for key in squeeze_keys:
                        if key in seq_data[person]:
                            seq_data[person][key] = seq_data[person][key].squeeze(0)
                interaction_text = sample_label('frame_labels_interaction')
                return seq_data['seq_name'], interaction_text, seq_data, self.max_length - sum(seq_data['person1']['padding_mask'])

            choice = 0 if random.random() < 0.5 else 1
            person_key = ('motion_p1', 'motion_p2')[choice]
            # Labels are stored under frame_labels_person{1,2}, not under the motion key.
            label_key = ('frame_labels_person1', 'frame_labels_person2')[choice]
            for key in squeeze_keys:
                if key in seq_data[person_key]:
                    seq_data[person_key][key] = seq_data[person_key][key].squeeze(0)
            text = sample_label(label_key)
            return seq_data['seq_name'], text, seq_data[person_key], self.max_length - sum(seq_data[person_key]['padding_mask'])

        # Canonicalization
        canonicalized_primitive_dict, transf_rotmat, transf_transl = {}, {}, {}
        for person, motion_key in person_idx.items():
            transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = \
                self.primitive_utility.canonicalize(copy.deepcopy(seq_data[motion_key]), use_predicted_joints=True)

        if self.mode == 'merged':
            # Relative pose of each person expressed in the other's canonical frame.
            rel_rotmat, rel_transl, rel_pose = {}, {}, {}
            rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(
                transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(
                transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
            for rel in ('b2a', 'a2b'):
                rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [1, 6+3]

        # calculate features
        feature_dict = {}
        for person in ('person1', 'person2'):
            features = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
            if self.primitive_utility.feature_dim == 276:
                # Drop the last frame of these per-frame features.
                for key in ('transl', 'poses_6d', 'joints'):
                    features[key] = features[key][:, :-1, :]
            feature_dict[person] = features

        def normalized_motion(person):
            motion = self.normalize(self.primitive_utility.dict_to_tensor(feature_dict[person]))  # [1, T, D]
            return motion.permute(0, 2, 1).unsqueeze(2)  # [1, D, 1, T]

        def person_entry(person, motion_tensor_normalized):
            source = seq_data[person_idx[person]]
            return {
                'gender': source['gender'],
                'betas': source['betas'].squeeze(0),
                'motion_tensor_normalized': motion_tensor_normalized.squeeze(0),  # [D, 1, T]
                'transf_rotmat': transf_rotmat[person].squeeze(0),
                'transf_transl': transf_transl[person].squeeze(0),
                'history_length': self.history_length,
                'future_length': self.future_length,
                'padding_mask': source['padding_mask'],
            }

        if self.mode == 'merged':
            data_batch = {}
            for person in ('person1', 'person2'):
                entry = person_entry(person, normalized_motion(person))
                if self.use_indi_text:
                    texts = [sample_label(f'frame_labels_{person}')]
                    entry['texts'] = texts[0]
                    # Only touch the embedding tables when they were loaded.
                    if self.load_text_embedding:
                        text_embedding, text_mask = lookup_text_embedding(embedding_table(person), texts)
                        entry['text_embedding'] = text_embedding.squeeze(0).detach().cpu()
                        if self.text_sep:
                            entry['text_mask'] = text_mask.squeeze(0).detach().cpu()
                data_batch[person] = entry

            texts = [sample_label('frame_labels_interaction')]
            data_batch['interaction'] = {
                'texts': texts[0],
                'rel_pose_b2a': rel_pose['b2a'].squeeze(0),
                'rel_pose_a2b': rel_pose['a2b'].squeeze(0),
            }
            if self.load_text_embedding:
                text_embedding, text_mask = lookup_text_embedding('interaction', texts)
                data_batch['interaction']['text_embedding'] = text_embedding.squeeze(0).detach().cpu()
                if self.text_sep:
                    data_batch['interaction']['text_mask'] = text_mask.squeeze(0).detach().cpu()

            return seq_data['seq_name'], data_batch['interaction']['texts'], data_batch, self.max_length - sum(data_batch['person1']['padding_mask'])

        data_batch = []
        for person in ('person1', 'person2'):
            motion_tensor_normalized = normalized_motion(person)
            texts = [sample_label(f'frame_labels_{person}')]
            entry = {'texts': texts[0], **person_entry(person, motion_tensor_normalized)}
            if self.load_text_embedding:
                # Same exchange-aware table lookup as the merged path.
                text_embedding, text_mask = lookup_text_embedding(embedding_table(person), texts)
                entry['text_embedding'] = text_embedding.squeeze(0).detach().cpu()
                if self.text_sep:
                    entry['text_mask'] = text_mask.squeeze(0).detach().cpu()
            data_batch.append(entry)
        choice = 0 if random.random() < 0.5 else 1
        return seq_data['seq_name'], data_batch[choice]['texts'], data_batch[choice], self.max_length - sum(data_batch[choice]['padding_mask'])

class InterHumanDatasetEvalV2(data.Dataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_single_interaction_d262_fps30_mirror_exchangeyz',
                 cfg_path='./config_files/config_hydra/motion_primitive/interhuman_h2_f8_r4.yaml',
                 prob_static=0.0,
                 weight_scheme='uniform',
                 split="test",
                 device='cuda',
                 load_data=True,
                 enforce_gender='male',
                 enforce_zero_beta = True, 
                 body_type='smplh',
                 mode = 'merged',
                 text_sep = True,
                 max_segs = 20,
                 min_length = 15,
                 max_length=300,
                 motion_repr=None,
                 padding=False,
                 **kwargs):
        """Load a two-person InterHuman split for evaluation.

        Steps:
            1. Load ``{split}.pkl`` from ``dataset_path``.
            2. Drop a fixed list of sequences, and sequences with fewer than
               ``min_length`` frames.
            3. Pad sequences with fewer than ``max_length`` transitions by
               repeating their last frame, recording padded frames in
               ``padding_mask``.
            4. Convert each person's motion to torch tensors.
            5. Load the feature mean/std.
            6. When ``load_text_embedding`` is passed in ``kwargs``, load the
               CLIP text-embedding tables, computing and caching them on first
               use.

        Args:
            dataset_path: directory with ``{split}.pkl``, the mean/std pickle and
                the cached text-embedding pickles.
            cfg_path: OmegaConf file providing ``fps``, ``history_length`` and
                ``future_length``.
            split: dataset split name used in file names.
            device: device for ``PrimitiveUtility`` and the CLIP model.
            load_data: load and preprocess sequences into ``self.dataset``.
            enforce_gender: if not None, overrides every person's gender.
            enforce_zero_beta: replace body-shape betas with zeros.
            body_type: forwarded to ``PrimitiveUtility``.
            mode: ``'merged'`` or per-person mode; selects ``self.key_list``.
            text_sep: store per-segment text embeddings and masks.
            max_segs: segment slots per text when ``text_sep`` is True.
            min_length: minimum number of frames a sequence must have.
            max_length: number of transitions sequences are padded up to.
            motion_repr: feature layout for ``PrimitiveUtility``. ``None``
                selects the default 22-joint layout.
            padding: use the ``_padding`` variant of the mean/std file.
            **kwargs: ``cut_length``, ``clip_version``, ``load_text_embedding``
                and ``use_indi_text`` are read from here.

        ``prob_static`` and ``weight_scheme`` are accepted for interface
        compatibility but are not used here.

        Raises:
            ValueError: ``PrimitiveUtility.feature_dim`` is neither 276 nor 262
                while sequences are being converted.
        """
        if motion_repr is None:
            # Built per call: a dict literal default would be a single object
            # shared by every instance (and handed to PrimitiveUtility).
            motion_repr = {
                'transl': 3,
                'poses_6d': 22 * 6,
                'transl_delta': 3,
                'global_orient_delta_6d': 6,
                'joints': 22 * 3,
                'joints_delta': 22 * 3,
            }

        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.min_length = min_length
        self.max_length = max_length

        self.cut_length = kwargs.get('cut_length', 0)
        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', False)
        self.use_indi_text = kwargs.get('use_indi_text', False)

        self.key_list = ['person1', 'person2', 'interaction'] if self.mode == 'merged' else ['person1', 'person2']
        self.text_key_list = ['person1', 'person2', 'interaction'] if self.use_indi_text else ['interaction']

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length

        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        if load_data:
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)

            # Sequences excluded from this dataset.
            elements_to_remove = {'7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385'}
            dataset = [
                seq for seq in dataset
                if seq['seq_name'] not in elements_to_remove
                and (self.cut_length <= 0 or seq['motion_p1']['trans'].shape[0] - 1 <= self.cut_length)
            ]

            pad_keys = ('trans', 'poses', 'joints', 'pose_body', 'global_orient')
            feature_dim = self.primitive_utility.feature_dim

            def convert_motion(motion, gender, enforce_zero_beta):
                # Turn one person's numpy motion dict into batched ([1, ...]) torch tensors.
                betas = torch.from_numpy(motion['betas'].astype(np.float32))
                if enforce_zero_beta:
                    betas = torch.zeros_like(betas)
                if feature_dim == 276:
                    poses = torch.from_numpy(motion['poses'].astype(np.float32))
                    global_orient = transforms.axis_angle_to_matrix(poses[:, :3])                       # [T, 3, 3]
                    body_pose = transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3))      # [T, 21, 3, 3]
                elif feature_dim == 262:
                    global_orient = transforms.axis_angle_to_matrix(torch.from_numpy(motion['global_orient'].astype(np.float32)))   # [T, 3, 3]
                    body_pose = torch.from_numpy(motion['pose_body'].astype(np.float32)).reshape(-1, 21, 6)                         # [T, 21, 6]
                else:
                    raise ValueError(f'unsupported feature_dim: {feature_dim} (expected 276 or 262)')
                transl = torch.from_numpy(motion['trans'].astype(np.float32))
                pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))              # [3]
                joints = torch.from_numpy(motion['joints'].astype(np.float32))                          # [T, 22, 3]
                return {
                    'gender': gender,
                    'betas': betas.unsqueeze(0).expand(1, transl.shape[0], 10),
                    'transl': transl.unsqueeze(0),
                    'global_orient': global_orient.unsqueeze(0),
                    'body_pose': body_pose.unsqueeze(0),
                    'pelvis_delta': pelvis_delta.unsqueeze(0),
                    'joints': joints.unsqueeze(0),
                    'transf_rotmat': torch.eye(3).unsqueeze(0),
                    'transf_transl': torch.zeros(1, 1, 3),
                    'padding_mask': motion['padding_mask'],
                }

            filtered_dataset = []
            for seq in dataset:
                T = seq['motion_p1']['trans'].shape[0] - 1
                if T < self.min_length - 1:
                    continue
                for person in ('motion_p1', 'motion_p2'):
                    if T < self.max_length:
                        # Repeat the last frame so every sequence reaches max_length transitions.
                        pad_len = self.max_length - T
                        for key in pad_keys:
                            if key in seq[person]:
                                last_frame = seq[person][key][-1:]
                                seq[person][key] = np.concatenate(
                                    [seq[person][key], np.repeat(last_frame, pad_len, axis=0)], axis=0)
                        # True marks padded frames.
                        seq[person]['padding_mask'] = np.concatenate(
                            [np.zeros(T, dtype=np.bool_), np.ones(pad_len, dtype=np.bool_)], axis=0)
                    else:
                        seq[person]['padding_mask'] = np.zeros(T, dtype=np.bool_)

                gender_p1 = self.enforce_gender if self.enforce_gender is not None else seq['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else seq['motion_p2']['gender']
                seq['motion_p1'] = convert_motion(seq['motion_p1'], gender_p1, self.enforce_zero_beta)
                seq['motion_p2'] = convert_motion(seq['motion_p2'], gender_p2, self.enforce_zero_beta)
                filtered_dataset.append(seq)

            # Report the sequences actually kept, after the min_length filter.
            print('num of sequences: ', len(filtered_dataset))

            self.dataset = filtered_dataset

        self.tensor_mean_device_dict = {}
        suffix = '_padding' if padding else ''
        mean_std_path = Path(dataset_path, f'mean_std_h{self.history_length}_f{self.future_length}{suffix}.pkl')
        try:
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        except FileNotFoundError:
            # Best effort: construction continues, but self.tensor_mean /
            # self.tensor_std stay unset.
            print('Error: mean and std not found!')

        # load clip model, get train text embeddings
        if self.load_text_embedding:
            self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
            self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
            suffix = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
            sep_tag = '_textsep' if text_sep else ''
            self.embedding_path = {
                key_type: Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{sep_tag}{suffix}.pkl')
                for key_type in self.text_key_list
            }
            self.text_embedding_dict = {}
            if text_sep:
                self.text_mask_dict = {}

            for key_type in self.text_key_list:
                embedding_file = self.embedding_path[key_type]
                mask_file = Path(str(embedding_file).replace('text_embedding_dict', 'text_mask_dict'))
                if embedding_file.exists():
                    print(f"Loading text_{key_type} embeddings from {embedding_file}!")
                    with open(embedding_file, 'rb') as f:
                        self.text_embedding_dict[key_type] = pickle.load(f)
                    if text_sep:
                        with open(mask_file, 'rb') as f:
                            self.text_mask_dict[key_type] = pickle.load(f)
                else:
                    print('Calculating text embeddings')
                    raw_texts = []
                    for seq in self.dataset:
                        if f'frame_labels_{key_type}' in seq:
                            raw_texts.extend([seg['proc_label'] for seg in seq['frame_labels_' + key_type]])

                    raw_texts = list(set(raw_texts))
                    num_texts = len(raw_texts)
                    print(f'num of unique texts_{key_type}: ', len(raw_texts))

                    # get text embeddings by batch
                    embedding_batches, mask_batches = [], []
                    for start in range(0, num_texts, 256):
                        encoded = self.encode_text(raw_texts[start:start + 256], text_sep=text_sep, max_segs=max_segs)
                        if text_sep:
                            embedding_batches.append(encoded[0])
                            mask_batches.append(encoded[1])
                        else:
                            embedding_batches.append(encoded)

                    # torch.cat rejects an empty list, so only concatenate when
                    # at least one label was found for this key.
                    embedding_table = {}
                    if embedding_batches:
                        all_embeddings = torch.cat(embedding_batches, dim=0).detach().cpu().numpy()
                        embedding_table = dict(zip(raw_texts, all_embeddings))
                    # Empty text maps to a zero embedding, compatible with mdm text masking.
                    if text_sep:
                        embedding_table[''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
                    else:
                        embedding_table[''] = np.zeros(self.dim_embed_text).astype(np.float32)
                    self.text_embedding_dict[key_type] = embedding_table
                    with open(embedding_file, 'wb') as f:
                        pickle.dump(embedding_table, f)

                    if text_sep:
                        mask_table = {}
                        if mask_batches:
                            all_masks = torch.cat(mask_batches, dim=0).detach().cpu().numpy()
                            mask_table = dict(zip(raw_texts, all_masks))
                        mask_table[''] = np.zeros(max_segs).astype(np.bool_)
                        self.text_mask_dict[key_type] = mask_table
                        with open(mask_file, 'wb') as f:
                            pickle.dump(mask_table, f)

                # Keep every cached entry as a CPU tensor (float32 embeddings, bool masks).
                embedding_table = self.text_embedding_dict[key_type]
                for key in embedding_table:
                    embedding_table[key] = torch.from_numpy(embedding_table[key]).to(dtype=torch.float32, device='cpu')
                    if text_sep:
                        mask_table = self.text_mask_dict[key_type]
                        mask_table[key] = torch.from_numpy(mask_table[key]).to(dtype=torch.bool, device='cpu')

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load the CLIP model ``clip_version`` onto ``device`` and freeze it.

        The model is stored on ``self.clip_model``, cast with
        ``clip.model.convert_weights``, put in eval mode, and every parameter
        gets ``requires_grad = False``.
        """
        loaded_model, _ = clip.load(clip_version, device=device, jit=False)  # jit=False is required for training
        self.clip_model = loaded_model
        # The original author notes this cast is redundant because CLIP already
        # loads in float16. NOTE(review): confirm that also holds for device='cpu'.
        clip.model.convert_weights(loaded_model)

        # Frozen feature extractor: inference mode, no gradients.
        loaded_model.eval()
        for param in loaded_model.parameters():
            param.requires_grad = False
    
    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs=20, sep_mode=0):
        """Encode captions with the frozen CLIP text encoder.

        Args:
            raw_text: list of caption strings (batch size B).
            force_empty_zero: zero the embedding of empty captions / segments.
            text_sep: if True, split every caption into at most ``max_segs``
                segments and encode each segment separately.
            max_segs: number of segment slots per caption when ``text_sep``.
            sep_mode: splitting rule when ``text_sep``. 0 splits on ``,`` and
                ``.``; 1 additionally splits on the words "and" / "while".

        Returns:
            A ``[B, D]`` float tensor when ``text_sep`` is False. Otherwise a
            tuple of ``[B, max_segs, D]`` embeddings and a ``[B, max_segs]``
            bool mask, where True marks an empty (padding) segment. ``D`` is
            ``self.dim_embed_text``.

        Raises:
            ValueError: if ``text_sep`` is True and ``sep_mode`` is not 0 or 1.
        """
        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, D]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding

        split_patterns = {0: r'[,.]', 1: r'\band\b|\bwhile\b|,|\.'}
        if sep_mode not in split_patterns:
            raise ValueError(f'unsupported sep_mode {sep_mode!r}; expected one of {sorted(split_patterns)}')

        import pandas as pd

        raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')
        # n=max_segs-1 caps the split so each caption yields at most max_segs pieces.
        split_df = raw_series.str.split(split_patterns[sep_mode], n=max_segs - 1, expand=True)
        # Column-wise .str.strip() instead of DataFrame.applymap (deprecated in pandas 2.1).
        split_df = split_df.fillna('').astype(str).apply(lambda col: col.str.strip())
        split_df = split_df.reindex(columns=range(max_segs), fill_value='')

        segs_matrix = split_df.values  # [B, max_segs]
        segs_flat = segs_matrix.reshape(-1).tolist()

        text_mask = (segs_matrix == '').astype(bool)  # True = empty segment
        text_mask = torch.tensor(text_mask, dtype=torch.bool, device=device)

        with torch.no_grad():
            tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
            text_embedding = self.clip_model.encode_text(tokenized).float()  # [B*max_segs, D]
            text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, D]
            if force_empty_zero:
                text_embedding[text_mask] = 0

        return text_embedding, text_mask

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        """Encode ``new_texts`` and add them to the ``key_type`` text caches.

        Entries are stored as detached CPU tensors, matching the CPU tensors the
        caches are populated with at construction. Cached and freshly encoded
        captions can therefore be combined, and the cache does not hold
        device memory.

        Args:
            new_texts: captions to encode.
            key_type: cache to update (e.g. ``'person1'`` or ``'interaction'``).
            text_sep: forwarded to ``encode_text``. When True, masks are stored
                in ``self.text_mask_dict[key_type]`` as well.
            max_segs: forwarded to ``encode_text``.
        """
        encoded = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs)
        if text_sep:
            embeddings, masks = encoded
            masks = masks.detach().cpu()
        else:
            embeddings, masks = encoded, None
        embeddings = embeddings.detach().cpu()

        for idx, text in enumerate(new_texts):
            self.text_embedding_dict[key_type][text] = embeddings[idx]
            if masks is not None:
                self.text_mask_dict[key_type][text] = masks[idx]
    
    def get_mean_std_by_device(self, device):
        """Return ``(tensor_mean, tensor_std)`` on ``device``.

        The moved copies are memoised in ``self.tensor_mean_device_dict`` so
        each device pays for the transfer only once.
        """
        cache = self.tensor_mean_device_dict
        try:
            return cache[device]
        except KeyError:
            moved = (self.tensor_mean.to(device=device), self.tensor_std.to(device=device))
            cache[device] = moved
            return moved

    def normalize(self, tensor):
        """Standardise ``tensor`` with the dataset statistics on its own device.

        Features whose std is exactly 0 are divided by 1 instead, so they are
        only mean-centred rather than producing inf/NaN.
        """
        mean, std = self.get_mean_std_by_device(tensor.device)
        safe_std = torch.where(std == 0, torch.ones_like(std), std)
        centered = tensor - mean
        return centered / safe_std  # [B, T, D]
    
    def denormalize(self, tensor):
        """Invert ``normalize``: scale by the dataset std, then add the mean."""
        mean, std = self.get_mean_std_by_device(tensor.device)
        rescaled = tensor * std
        return rescaled + mean  # [B, T, D]
    
    def __len__(self):
        """Number of sequences held in ``self.dataset``."""
        sequences = self.dataset
        return len(sequences)
        
    def __getitem__(self, idx):
        """Return one randomly augmented sequence.

        A deep copy of ``self.dataset[idx]`` is taken. With probability 0.5 the
        two people are swapped (recorded in ``seq_data['exchange']``). Sequences
        with more than ``max_length + 1`` frames are randomly cropped to
        ``max_length + 1`` frames.

        Returns:
            ``(seq_name, text, data, valid_steps)`` where ``valid_steps`` is
            ``max_length`` minus the number of padded steps.

            - In ``'merged'`` mode, ``text`` is a sampled interaction caption
              and ``data`` holds ``person1``/``person2``. It also holds the
              interaction or per-person text embeddings when
              ``load_text_embedding`` is set.
            - Otherwise one person is sampled, ``data`` is that person's
              motion dict, and ``text`` is one of that person's captions.
        """
        squeeze_keys = ('betas', 'transl', 'global_orient', 'body_pose', 'pelvis_delta',
                        'joints', 'transf_rotmat', 'transf_transl')

        seq_data = copy.deepcopy(self.dataset[idx])
        exchange = random.random() < 0.5
        if exchange:
            seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
            # Per-person captions must follow their motions. They are read when
            # individual texts are enabled, and always in single-person mode.
            if self.use_indi_text or self.mode != 'merged':
                seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = (
                    seq_data['frame_labels_person2'], seq_data['frame_labels_person1'])
        seq_data['exchange'] = exchange
        length = seq_data['motion_p1']['betas'].shape[1]  # frames live on dim 1

        # Truncate sequence if too long: keep max_length + 1 frames / max_length steps.
        if length > self.max_length + 1:
            cut_idx = random.choice(range(length - self.max_length))
            for person in ('motion_p1', 'motion_p2'):
                motion = seq_data[person]
                for key in ('betas', 'transl', 'global_orient', 'body_pose', 'joints'):
                    motion[key] = motion[key][:, cut_idx:cut_idx + self.max_length + 1]
                motion['padding_mask'] = motion['padding_mask'][cut_idx:cut_idx + self.max_length]

        if self.mode != 'merged':
            choice = 0 if random.random() < 0.5 else 1
            motion = seq_data[('motion_p1', 'motion_p2')[choice]]
            for key in squeeze_keys:
                if key in motion:
                    motion[key] = motion[key].squeeze(0)
            # Caption keys are named after person1/person2, not motion_p1/motion_p2.
            text = random.choice([
                itext['proc_label'] for itext in seq_data[f'frame_labels_person{choice + 1}']
            ])
            return seq_data['seq_name'], text, motion, self.max_length - sum(motion['padding_mask'])

        seq_data['person1'] = seq_data.pop('motion_p1')
        seq_data['person2'] = seq_data.pop('motion_p2')
        for person in ('person1', 'person2'):
            for key in squeeze_keys:
                if key in seq_data[person]:
                    seq_data[person][key] = seq_data[person][key].squeeze(0)
        interaction_text = random.choice([
            itext['proc_label'] for itext in seq_data['frame_labels_interaction']
        ])

        if self.use_indi_text:
            for person in ('person1', 'person2'):
                # proc_label is one caption string: look it up as a whole, not per character.
                person_text = random.choice([
                    itext['proc_label'] for itext in seq_data[f'frame_labels_{person}']
                ])
                if self.load_text_embedding:
                    if person_text not in self.text_embedding_dict[person]:
                        self.update_text_embedding_dict([person_text], person, text_sep=self.text_sep, max_segs=self.max_segs)
                    seq_data[person]['text_embedding'] = self.text_embedding_dict[person][person_text].unsqueeze(0).detach().cpu()
                    seq_data[person]['text_mask'] = (
                        self.text_mask_dict[person][person_text].unsqueeze(0).detach().cpu()
                        if self.text_sep else None
                    )
                seq_data[person]['texts'] = person_text
        elif self.load_text_embedding:
            if interaction_text not in self.text_embedding_dict['interaction']:
                self.update_text_embedding_dict([interaction_text], 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
            seq_data['interaction'] = {
                'texts': interaction_text,
                'text_embedding': self.text_embedding_dict['interaction'][interaction_text].unsqueeze(0).detach().cpu(),
            }
            if self.text_sep:
                seq_data['interaction']['text_mask'] = self.text_mask_dict['interaction'][interaction_text].unsqueeze(0).detach().cpu()

        return seq_data['seq_name'], interaction_text, seq_data, self.max_length - sum(seq_data['person1']['padding_mask'])

class InterHumanDatasetEvalV2WPE(data.Dataset):
    """Whole-sequence two-person InterHuman dataset used for evaluation.

    Each item is a complete two-person sequence rather than a motion primitive.
    At load time, sequences with fewer than ``max_length`` motion steps are
    padded by repeating their last frame; padded steps are flagged True in
    ``padding_mask``. In ``__getitem__``, longer sequences are randomly cropped
    to ``max_length + 1`` frames. Optionally, CLIP text embeddings for every
    caption type in ``text_key_list`` are computed once and cached on disk.
    """

    # Sequence names excluded from the dataset.
    _EXCLUDED_SEQS = ('7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385')
    # Raw per-frame numpy arrays padded with their last frame when a sequence is short.
    _PAD_KEYS = ('trans', 'poses', 'joints', 'pose_body', 'global_orient')
    # Converted tensors sliced along dim 1 (frames) when a sequence is cropped.
    _CROP_KEYS = ('betas', 'transl', 'global_orient', 'body_pose', 'joints')
    # Converted tensors whose leading singleton dim is squeezed before returning.
    _SQUEEZE_KEYS = ('betas', 'transl', 'global_orient', 'body_pose', 'pelvis_delta',
                     'joints', 'transf_rotmat', 'transf_transl')
    # Caption-splitting regex for encode_text(text_sep=True), keyed by sep_mode.
    _SEP_PATTERNS = {0: r'[,.]', 1: r'\band\b|\bwhile\b|,|\.'}
    # Chunk size used when encoding all captions to build a cache.
    _ENCODE_BATCH = 256

    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_single_interaction_d262_fps30_mirror_exchangeyz',
                 cfg_path='./config_files/config_hydra/motion_primitive/interhuman_h2_f8_r4.yaml',
                 prob_static=0.0,
                 weight_scheme='uniform',
                 split="test",
                 device='cuda',
                 load_data=True,
                 enforce_gender='male',
                 enforce_zero_beta=True,
                 body_type='smplh',
                 mode='merged',
                 text_sep=True,
                 max_segs=20,
                 min_length=15,
                 max_length=300,
                 motion_repr={
                     'transl': 3,
                     'poses_6d': 22 * 6,
                     'transl_delta': 3,
                     'global_orient_delta_6d': 6,
                     'joints': 22 * 3,
                     'joints_delta': 22 * 3,
                 },
                 padding=False,
                 **kwargs):
        """
        Args:
            dataset_path: directory holding ``{split}.pkl``, the mean/std pickle
                and the text-embedding caches.
            cfg_path: OmegaConf YAML providing ``fps``, ``history_length`` and
                ``future_length``.
            prob_static, weight_scheme: accepted for signature compatibility;
                not used by this class.
            split: split name used in the data and cache file names.
            device: device for ``PrimitiveUtility`` and the CLIP model.
            load_data: if False, no sequences are loaded and ``self.dataset`` is empty.
            enforce_gender: gender used for both people instead of the stored
                one (None keeps the stored gender).
            enforce_zero_beta: replace the stored ``betas`` with zeros.
            mode: ``'merged'`` returns both people per item; any other value
                returns one sampled person.
            text_sep: split captions into segments before CLIP encoding.
            max_segs: segment slots per caption when ``text_sep``.
            min_length: sequences with fewer than ``min_length - 1`` motion
                steps are dropped.
            max_length: number of motion steps items are padded / cropped to.
            motion_repr: feature layout forwarded to ``PrimitiveUtility``.
            padding: selects the ``_padding`` variant of the mean/std file.
            **kwargs: ``cut_length`` (when > 0, drop sequences with more steps),
                ``clip_version``, ``load_text_embedding``, ``use_indi_text``.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.min_length = min_length
        self.max_length = max_length

        self.cut_length = kwargs.get('cut_length', 0)
        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', False)
        self.use_indi_text = kwargs.get('use_indi_text', False)

        self.key_list = ['person1', 'person2', 'interaction'] if self.mode == 'merged' else ['person1', 'person2']
        self.text_key_list = ['person1', 'person2', 'interaction'] if self.use_indi_text else ['interaction']

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length

        # Defensive copy: the default argument is one dict shared by every call.
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type,
                                                  motion_repr=copy.deepcopy(motion_repr))
        self.motion_repr = self.primitive_utility.motion_repr

        # An empty list keeps __len__ well-defined when no data is loaded.
        self.dataset = self._load_sequences(dataset_path, split) if load_data else []

        self.tensor_mean_device_dict = {}
        self._load_mean_std(dataset_path, padding)

        # load clip model, get text embeddings
        if self.load_text_embedding:
            self._init_text_embeddings(dataset_path, split)

    # ------------------------------------------------------------------ loading

    def _load_sequences(self, dataset_path, split):
        """Load ``{split}.pkl``, filter it, then pad and convert every kept sequence."""
        with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
            raw_sequences = pickle.load(f)

        sequences = []
        for seq in raw_sequences:
            if seq['seq_name'] in self._EXCLUDED_SEQS:
                continue
            num_steps = seq['motion_p1']['trans'].shape[0] - 1
            if self.cut_length > 0 and num_steps > self.cut_length:
                continue
            if num_steps < self.min_length - 1:
                continue
            self._pad_sequence(seq, num_steps)
            genders = {
                person: self.enforce_gender if self.enforce_gender is not None else seq[person]['gender']
                for person in ('motion_p1', 'motion_p2')
            }
            for person in ('motion_p1', 'motion_p2'):
                seq[person] = self._convert_motion(seq[person], genders[person], self.enforce_zero_beta)
            sequences.append(seq)

        print('num of sequences: ', len(sequences))
        return sequences

    def _pad_sequence(self, seq, num_steps):
        """Pad both people of ``seq`` in place to ``max_length`` steps and set ``padding_mask``.

        Padding repeats the last frame of every array in ``_PAD_KEYS``.
        ``padding_mask`` has one bool per motion step; True marks padded steps.
        """
        pad_len = max(self.max_length - num_steps, 0)
        for person in ('motion_p1', 'motion_p2'):
            motion = seq[person]
            if pad_len > 0:
                for key in self._PAD_KEYS:
                    if key in motion:
                        last_frame = motion[key][-1:]
                        motion[key] = np.concatenate([motion[key], np.repeat(last_frame, pad_len, axis=0)], axis=0)
            motion['padding_mask'] = np.concatenate([
                np.zeros(num_steps, dtype=np.bool_),
                np.ones(pad_len, dtype=np.bool_),
            ], axis=0)

    def _convert_motion(self, motion, gender, enforce_zero_beta):
        """Convert one person's raw numpy motion dict into tensors with a leading batch dim of 1.

        Raises:
            ValueError: if ``primitive_utility.feature_dim`` is neither 276 nor 262.
        """
        betas = torch.from_numpy(motion['betas'].astype(np.float32))
        if enforce_zero_beta:
            betas = torch.zeros_like(betas)
        feature_dim = self.primitive_utility.feature_dim
        if feature_dim == 276:
            poses = torch.from_numpy(motion['poses'].astype(np.float32))
            global_orient = transforms.axis_angle_to_matrix(poses[:, :3])                       # [T, 3, 3]
            body_pose = transforms.axis_angle_to_matrix(poses[:, 3:66].reshape(-1, 21, 3))      # [T, 21, 3, 3]
        elif feature_dim == 262:
            global_orient = transforms.axis_angle_to_matrix(torch.from_numpy(motion['global_orient'].astype(np.float32)))   # [T, 3, 3]
            body_pose = torch.from_numpy(motion['pose_body'].astype(np.float32)).reshape(-1, 21, 6)                         # [T, 21, 6]
        else:
            raise ValueError(f'unsupported feature_dim {feature_dim}; expected 276 or 262')
        transl = torch.from_numpy(motion['trans'].astype(np.float32))
        pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))
        joints = torch.from_numpy(motion['joints'].astype(np.float32))
        return {
            'gender': gender,
            'betas': betas.unsqueeze(0).expand(1, transl.shape[0], 10),
            'transl': transl.unsqueeze(0),
            'global_orient': global_orient.unsqueeze(0),
            'body_pose': body_pose.unsqueeze(0),
            'pelvis_delta': pelvis_delta.unsqueeze(0),
            'joints': joints.unsqueeze(0),
            'transf_rotmat': torch.eye(3).unsqueeze(0),
            'transf_transl': torch.zeros(1, 1, 3),
            'padding_mask': motion['padding_mask'],
        }

    def _load_mean_std(self, dataset_path, padding):
        """Load ``tensor_mean`` / ``tensor_std`` from the dataset directory if the file exists.

        A missing file is reported but not fatal; ``normalize``/``denormalize``
        fail until the statistics are assigned.
        """
        suffix = '_padding' if padding else ''
        mean_std_path = Path(dataset_path, f'mean_std_h{self.history_length}_f{self.future_length}{suffix}.pkl')
        print(f'loading mean and std from {mean_std_path}')
        try:
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D] per the original note
        except FileNotFoundError:
            print('Error: mean and std not found!')

    # ------------------------------------------------------------ text caches

    @staticmethod
    def _mask_path(embedding_path):
        """Path of the text-mask cache that accompanies ``embedding_path``."""
        return Path(str(embedding_path).replace('text_embedding_dict', 'text_mask_dict'))

    def _init_text_embeddings(self, dataset_path, split):
        """Load CLIP and fill ``text_embedding_dict`` (and ``text_mask_dict``) per caption type.

        Each type in ``text_key_list`` is read from its pickle cache when that
        file exists; otherwise it is encoded from the loaded sequences and
        written to disk. All entries end up as CPU tensors (float32 embeddings,
        bool masks).
        """
        self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
        self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
        version_tag = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
        sep_tag = '_textsep' if self.text_sep else ''
        self.embedding_path = {
            key_type: Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{sep_tag}{version_tag}.pkl')
            for key_type in self.text_key_list
        }
        self.text_embedding_dict = {}
        if self.text_sep:
            self.text_mask_dict = {}

        for key_type in self.text_key_list:
            emb_path = self.embedding_path[key_type]
            if emb_path.exists():
                print(f"Loading text_{key_type} embeddings from {emb_path}!")
                with open(emb_path, 'rb') as f:
                    self.text_embedding_dict[key_type] = pickle.load(f)
                if self.text_sep:
                    with open(self._mask_path(emb_path), 'rb') as f:
                        self.text_mask_dict[key_type] = pickle.load(f)
            else:
                self._build_text_cache(key_type, emb_path)

            for key in self.text_embedding_dict[key_type]:
                self.text_embedding_dict[key_type][key] = torch.from_numpy(
                    self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device='cpu')
                if self.text_sep:
                    self.text_mask_dict[key_type][key] = torch.from_numpy(
                        self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device='cpu')

    def _build_text_cache(self, key_type, emb_path):
        """Encode every unique ``frame_labels_{key_type}`` caption and pickle the results.

        The empty caption '' maps to an all-zero embedding and, with
        ``text_sep``, an all-False mask. Values are left as numpy arrays; the
        caller converts them to tensors.

        Raises:
            RuntimeError: if the loaded sequences contain no such captions.
        """
        print('Calculating text embeddings')
        label_key = f'frame_labels_{key_type}'
        raw_texts = []
        for seq in self.dataset:
            if label_key in seq:
                raw_texts.extend(seg['proc_label'] for seg in seq[label_key])
        raw_texts = list(set(raw_texts))
        print(f'num of unique texts_{key_type}: ', len(raw_texts))
        if not raw_texts:
            # torch.cat([]) would raise RuntimeError anyway; fail with a clear message before writing.
            raise RuntimeError(f'no {label_key!r} captions found; cannot build {emb_path}')

        embeddings, masks = [], []
        for start in range(0, len(raw_texts), self._ENCODE_BATCH):
            encoded = self.encode_text(raw_texts[start:start + self._ENCODE_BATCH],
                                       text_sep=self.text_sep, max_segs=self.max_segs)
            if self.text_sep:
                embeddings.append(encoded[0])
                masks.append(encoded[1])
            else:
                embeddings.append(encoded)
        embeddings = torch.cat(embeddings, dim=0).detach().cpu().numpy()

        emb_dict = {text: embeddings[i] for i, text in enumerate(raw_texts)}
        # Empty text gets a zero embedding, compatible with MDM-style text masking.
        if self.text_sep:
            emb_dict[''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
        else:
            emb_dict[''] = np.zeros(self.dim_embed_text).astype(np.float32)
        self.text_embedding_dict[key_type] = emb_dict
        with open(emb_path, 'wb') as f:
            pickle.dump(emb_dict, f)

        if self.text_sep:
            masks = torch.cat(masks, dim=0).detach().cpu().numpy()
            mask_dict = {text: masks[i] for i, text in enumerate(raw_texts)}
            mask_dict[''] = np.zeros(self.max_segs).astype(np.bool_)
            self.text_mask_dict[key_type] = mask_dict
            with open(self._mask_path(emb_path), 'wb') as f:
                pickle.dump(mask_dict, f)

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load the CLIP model ``clip_version`` onto ``device``, cast it, and freeze it."""
        self.clip_model, _ = clip.load(clip_version, device=device, jit=False)  # Must set jit=False for training
        # The original author notes this cast is redundant because CLIP already
        # loads in float16. NOTE(review): confirm that also holds for device='cpu'.
        clip.model.convert_weights(self.clip_model)

        # Freeze CLIP weights
        self.clip_model.eval()
        for p in self.clip_model.parameters():
            p.requires_grad = False

    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs=20, sep_mode=0):
        """Encode captions with the frozen CLIP text encoder.

        Args:
            raw_text: list of caption strings (batch size B).
            force_empty_zero: zero the embedding of empty captions / segments.
            text_sep: split each caption into at most ``max_segs`` segments and
                encode each segment separately.
            max_segs: segment slots per caption when ``text_sep``.
            sep_mode: key into ``_SEP_PATTERNS``. 0 splits on ``,``/``.``;
                1 also splits on the words "and"/"while".

        Returns:
            ``[B, D]`` float tensor without ``text_sep``. Otherwise a tuple of
            ``[B, max_segs, D]`` embeddings and a ``[B, max_segs]`` bool mask
            (True = empty segment). ``D`` is ``self.dim_embed_text``.

        Raises:
            ValueError: if ``text_sep`` is True and ``sep_mode`` is unknown.
        """
        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, D]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding

        if sep_mode not in self._SEP_PATTERNS:
            raise ValueError(f'unsupported sep_mode {sep_mode!r}; expected one of {sorted(self._SEP_PATTERNS)}')

        import pandas as pd

        raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')
        # n=max_segs-1 caps the split so each caption yields at most max_segs pieces.
        split_df = raw_series.str.split(self._SEP_PATTERNS[sep_mode], n=max_segs - 1, expand=True)
        # Column-wise .str.strip() instead of DataFrame.applymap (deprecated in pandas 2.1).
        split_df = split_df.fillna('').astype(str).apply(lambda col: col.str.strip())
        split_df = split_df.reindex(columns=range(max_segs), fill_value='')

        segs_matrix = split_df.values  # [B, max_segs]
        segs_flat = segs_matrix.reshape(-1).tolist()
        text_mask = torch.tensor((segs_matrix == '').astype(bool), dtype=torch.bool, device=device)

        with torch.no_grad():
            tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
            text_embedding = self.clip_model.encode_text(tokenized).float()  # [B*max_segs, D]
            text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, D]
            if force_empty_zero:
                text_embedding[text_mask] = 0

        return text_embedding, text_mask

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        """Encode ``new_texts`` and add them to the ``key_type`` caches as detached CPU tensors.

        Storing on CPU matches the tensors built in ``_init_text_embeddings``.
        """
        encoded = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs)
        if text_sep:
            embeddings, masks = encoded
            masks = masks.detach().cpu()
        else:
            embeddings, masks = encoded, None
        embeddings = embeddings.detach().cpu()
        for idx, text in enumerate(new_texts):
            self.text_embedding_dict[key_type][text] = embeddings[idx]
            if masks is not None:
                self.text_mask_dict[key_type][text] = masks[idx]

    def _lookup_text_embedding(self, key_type, text):
        """Return ``([1, ...] embedding, [1, max_segs] mask or None)`` for ``text``, encoding it if unseen."""
        if text not in self.text_embedding_dict[key_type]:
            self.update_text_embedding_dict([text], key_type, text_sep=self.text_sep, max_segs=self.max_segs)
        embedding = self.text_embedding_dict[key_type][text].unsqueeze(0).detach().cpu()
        mask = self.text_mask_dict[key_type][text].unsqueeze(0).detach().cpu() if self.text_sep else None
        return embedding, mask

    # ------------------------------------------------------------ statistics

    def get_mean_std_by_device(self, device):
        """Return ``(tensor_mean, tensor_std)`` on ``device``, caching one copy per device."""
        if device not in self.tensor_mean_device_dict:
            self.tensor_mean_device_dict[device] = (self.tensor_mean.to(device=device), self.tensor_std.to(device=device))
        return self.tensor_mean_device_dict[device]

    def normalize(self, tensor):
        """Standardise ``tensor``; features with zero std are divided by 1."""
        tensor_mean, tensor_std = self.get_mean_std_by_device(tensor.device)
        tensor_std_safe = tensor_std.clone()
        tensor_std_safe[tensor_std == 0] = 1.0  # avoid division by zero
        return (tensor - tensor_mean) / tensor_std_safe  # [B, T, D]

    def denormalize(self, tensor):
        """Invert ``normalize``: ``tensor * std + mean``."""
        tensor_mean, tensor_std = self.get_mean_std_by_device(tensor.device)
        return tensor * tensor_std + tensor_mean  # [B, T, D]

    # ------------------------------------------------------------ items

    def __len__(self):
        return len(self.dataset)

    def _squeeze_batch_dim(self, motion):
        """Drop the leading singleton dim of every tensor in ``_SQUEEZE_KEYS`` (in place)."""
        for key in self._SQUEEZE_KEYS:
            if key in motion:
                motion[key] = motion[key].squeeze(0)

    def __getitem__(self, idx):
        """Return one randomly augmented sequence.

        A deep copy of ``self.dataset[idx]`` is taken. With probability 0.5 the
        two people are swapped (recorded in ``seq_data['exchange']``). Each
        person gets ``start_frame`` and ``total_frames`` entries. Sequences
        with more than ``max_length + 1`` frames are randomly cropped to
        ``max_length + 1`` frames.

        Returns:
            ``(seq_name, text, data, valid_steps)``; see ``_merged_item`` and
            ``_single_person_item``.
        """
        seq_data = copy.deepcopy(self.dataset[idx])
        exchange = random.random() < 0.5
        if exchange:
            seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
            # Per-person captions must follow their motions. They are read when
            # individual texts are enabled, and always in single-person mode.
            if self.use_indi_text or self.mode != 'merged':
                seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = (
                    seq_data['frame_labels_person2'], seq_data['frame_labels_person1'])
        seq_data['exchange'] = exchange
        length = seq_data['motion_p1']['betas'].shape[1]  # frames live on dim 1

        total_frames = (length - sum(seq_data['motion_p1']['padding_mask'])) + 1
        for person in ('motion_p1', 'motion_p2'):
            seq_data[person]['start_frame'] = torch.tensor(0).view(1)
            seq_data[person]['total_frames'] = torch.tensor(total_frames).view(1)

        # Truncate sequence if too long: keep max_length + 1 frames / max_length steps.
        if length > self.max_length + 1:
            cut_idx = random.choice(range(length - self.max_length))
            for person in ('motion_p1', 'motion_p2'):
                motion = seq_data[person]
                motion['start_frame'] = torch.tensor(cut_idx).view(1)
                for key in self._CROP_KEYS:
                    motion[key] = motion[key][:, cut_idx:cut_idx + self.max_length + 1]
                motion['padding_mask'] = motion['padding_mask'][cut_idx:cut_idx + self.max_length]

        if self.mode == 'merged':
            return self._merged_item(seq_data)
        return self._single_person_item(seq_data)

    def _merged_item(self, seq_data):
        """Build a two-person item: ``(seq_name, interaction_text, seq_data, valid_steps)``."""
        seq_data['person1'] = seq_data.pop('motion_p1')
        seq_data['person2'] = seq_data.pop('motion_p2')
        for person in ('person1', 'person2'):
            self._squeeze_batch_dim(seq_data[person])
        interaction_text = random.choice([
            itext['proc_label'] for itext in seq_data['frame_labels_interaction']
        ])

        if self.use_indi_text:
            for person in ('person1', 'person2'):
                # proc_label is one caption string: look it up as a whole, not per character.
                person_text = random.choice([
                    itext['proc_label'] for itext in seq_data[f'frame_labels_{person}']
                ])
                if self.load_text_embedding:
                    seq_data[person]['text_embedding'], seq_data[person]['text_mask'] = \
                        self._lookup_text_embedding(person, person_text)
                seq_data[person]['texts'] = person_text
        elif self.load_text_embedding:
            embedding, mask = self._lookup_text_embedding('interaction', interaction_text)
            seq_data['interaction'] = {
                'texts': interaction_text,
                'text_embedding': embedding,
            }
            if self.text_sep:
                seq_data['interaction']['text_mask'] = mask

        valid_steps = self.max_length - sum(seq_data['person1']['padding_mask'])
        return seq_data['seq_name'], interaction_text, seq_data, valid_steps

    def _single_person_item(self, seq_data):
        """Sample one person: ``(seq_name, caption, motion_dict, valid_steps)``."""
        choice = 0 if random.random() < 0.5 else 1
        motion = seq_data[('motion_p1', 'motion_p2')[choice]]
        self._squeeze_batch_dim(motion)
        # Caption keys are named after person1/person2, not motion_p1/motion_p2.
        text = random.choice([
            itext['proc_label'] for itext in seq_data[f'frame_labels_person{choice + 1}']
        ])
        return seq_data['seq_name'], text, motion, self.max_length - sum(motion['padding_mask'])


# dataset_test = InterHumanDatasetEval(device='cuda:1', 
#                                      opt='generate',
#                                      motion_repr = {
#                                         'joints': 22 * 3,
#                                         'joints_delta': 22 * 3,
#                                         'body_pose': 21 * 6,
#                                         'feet_contact': 4,
#                                     },
#                                      load_text_embedding=False,
#                                      use_indi_text=False,)
# print(len(dataset_test))
# data_test = dataset_test[0]


class InterGenDataset(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_single_interaction_d262_fps30_mirror_exchangeyz',
                 cfg_path='./config_files/config_hydra/motion_primitive/interhuman_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='merged',  # 'sep' or 'merged'
                 text_sep=False,
                 max_segs=20,
                 motion_repr={'joints': 22 * 3,
                              'joints_delta': 22 * 3,
                              'body_pose': 21 * 6,
                              'feet_contact': 4, },
                 **kwargs):
        """Two-person motion-primitive dataset built from pre-processed InterHuman sequences.

        Loads ``{split}.pkl`` from ``dataset_path``, optionally pads short sequences by
        repeating their last frame, converts both persons' motions to float32 tensors,
        assigns per-sequence sampling weights, loads (or, on the train split, computes and
        caches) feature mean/std statistics and optionally loads or computes CLIP text
        embeddings for the sequence labels.

        Args:
            dataset_path: directory holding ``{split}.pkl`` plus cached statistics/embeddings.
            cfg_path: YAML config read for ``fps``, ``history_length``, ``future_length``
                and ``num_primitive``.
            split: data split name; statistics can only be computed when it is 'train'.
            weight_scheme: must contain 'uniform' (weight 1 per sequence) or 'length'
                (weight = number of real frames).
            enforce_gender: if not None, overrides the gender of both persons.
            enforce_zero_beta: if truthy, replaces all betas with zeros.
            load_data: if False, no sequences are loaded (``self.dataset`` is not set).
            seed_only: requires ``num_primitive == 1``.
            mode: 'merged' adds an 'interaction' entry to ``self.key_list``.
            text_sep: encode each label as separate segments (with a mask) instead of one vector.
            max_segs: number of segment slots per label when ``text_sep``.
            motion_repr: feature layout forwarded to ``PrimitiveUtility``; a shallow copy is
                passed so the shared default dict is never mutated through an instance.
            dataset_name, prob_static: stored as attributes; use_frame_weights: unused here.
            **kwargs: optional switches ``sep_mode``, ``padding``, ``normalize_relpose``,
                ``use_interaction_model``, ``clip_version``, ``load_text_embedding``,
                ``use_indi_text``.

        Raises:
            ValueError: if ``weight_scheme`` contains neither 'uniform' nor 'length'.
            RuntimeError: if text embeddings must be computed but no data was loaded.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        # Optional behaviour switches passed through **kwargs.
        self.sep_mode = kwargs.get('sep_mode', 0)
        self.padding = kwargs.get('padding', False)
        self.normalize_relpose = kwargs.get('normalize_relpose', False)
        self.use_interaction_model = kwargs.get('use_interaction_model', False)
        self.key_list = ['person1', 'person2', 'interaction'] if self.mode == 'merged' else ['person1', 'person2']
        self.feet_thre = 0.001
        self.n_joints = 22

        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', False)
        self.use_indi_text = kwargs.get('use_indi_text', False)
        self.text_key_list = ['person1', 'person2', 'interaction'] if self.use_indi_text else ['interaction']

        # Copy so an in-place edit of the layout (via PrimitiveUtility or self.motion_repr)
        # cannot leak into the shared default used by later instances.
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=copy.copy(motion_repr))
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # Frames spanned by num_primitive chained primitives; the +1 matches the inclusive
        # end frame used when slicing primitives.
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1
        self.min_length = self.history_length + self.future_length + 1

        if load_data:
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            if not self.padding:
                # Without padding every sequence must cover a full training window.
                dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            # Hand-picked sequences that are always excluded (reason not recorded here).
            elements_to_remove = ['7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385']
            dataset = [data for data in dataset
                       if (data['seq_name'] not in elements_to_remove and data['motion_p1']['trans'].shape[0] >= self.min_length)]

            def convert_motion(motion, gender, enforce_zero_beta):
                """Convert one person's numpy motion record into float32 torch tensors."""
                betas = torch.from_numpy(motion['betas'].astype(np.float32))
                if enforce_zero_beta:
                    betas = torch.zeros_like(betas)
                transl = torch.from_numpy(motion['trans'].astype(np.float32))
                # global_orient is stored as axis-angle and converted to rotation matrices.
                global_orient = transforms.axis_angle_to_matrix(torch.from_numpy(motion['global_orient'].astype(np.float32)))
                body_pose = torch.from_numpy(motion['pose_body'].astype(np.float32)).reshape(-1, 21, 6)  # [T, 21, 6]
                pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))
                joints = torch.from_numpy(motion['joints'].astype(np.float32))
                result = {
                    'gender': gender,
                    'betas': betas,
                    'transl': transl,
                    'global_orient': global_orient,
                    'body_pose': body_pose,
                    'pelvis_delta': pelvis_delta,
                    'joints': joints,
                }
                if self.padding:
                    result['padding_mask'] = motion['padding_mask']
                return result

            for data in dataset:
                if self.padding:
                    T = data['motion_p1']['trans'].shape[0]
                    if T < self.seq_length:
                        # Pad by repeating the last frame; padding_mask marks the padded frames.
                        pad_len = self.seq_length - T
                        for person in ['motion_p1', 'motion_p2']:
                            for key in ['trans', 'global_orient', 'pose_body', 'joints']:
                                last_frame = data[person][key][-1:]
                                data[person][key] = np.concatenate([
                                    data[person][key],
                                    np.repeat(last_frame, pad_len, axis=0),
                                ], axis=0)
                            data[person]['padding_mask'] = np.concatenate([
                                np.zeros(T, dtype=np.bool_),
                                np.ones(pad_len, dtype=np.bool_),
                            ], axis=0)
                    else:
                        for person in ['motion_p1', 'motion_p2']:
                            data[person]['padding_mask'] = np.zeros(T, dtype=np.bool_)

                gender_p1 = self.enforce_gender if self.enforce_gender is not None else data['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else data['motion_p2']['gender']
                data['motion_p1'] = convert_motion(data['motion_p1'], gender_p1, self.enforce_zero_beta)
                data['motion_p2'] = convert_motion(data['motion_p2'], gender_p2, self.enforce_zero_beta)

            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    if self.padding and len(data['motion_p1']['transl']) == self.seq_length:
                        # Padded sequences are weighted by their real (unpadded) length.
                        data['weight'] = len(data['motion_p1']['transl']) - sum(data['motion_p1']['padding_mask'])
                    else:
                        data['weight'] = len(data['motion_p1']['transl'])
                else:
                    raise ValueError(f"weight_scheme must contain 'uniform' or 'length', got {weight_scheme!r}")
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights
        self._curr_test_index = 0

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        # Default the relative statistics to None so the get_*_by_device helpers report a
        # clear assertion instead of an AttributeError when they were never computed/loaded.
        self.rel_mean, self.rel_std = None, None
        self.relpose_mean, self.relpose_std = None, None
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'
        padding_suffix = '_padding' if self.padding else ''
        mean_std_path = Path(dataset_path, f'{file_name}{padding_suffix}.pkl')
        mean_std_relpose_path = Path(dataset_path, f'{file_name}_relpose{padding_suffix}.pkl')
        mean_std_interaction_path = Path(dataset_path, f'{file_name}_interaction{padding_suffix}.pkl')

        cache_complete = (mean_std_path.exists()
                          and (not self.use_interaction_model or mean_std_interaction_path.exists())
                          and (not self.normalize_relpose or mean_std_relpose_path.exists()))
        if cache_complete:
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]

            if self.normalize_relpose:
                print(f'loading relpose mean and std from {mean_std_relpose_path}')
                with open(mean_std_relpose_path, 'rb') as f:
                    self.relpose_mean, self.relpose_std = pickle.load(f)

            if self.use_interaction_model:
                print(f'loading interaction mean and std from {mean_std_interaction_path}')
                with open(mean_std_interaction_path, 'rb') as f:
                    self.rel_mean, self.rel_std = pickle.load(f)  # [1, 1, D]
        else:
            # Statistics are only ever derived from the training split.
            assert self.split == 'train'
            print('calculating mean and std using train split')
            result = self.calc_mean_std()

            if self.use_interaction_model:
                self.tensor_mean, self.tensor_std, self.rel_mean, self.rel_std = result
            elif self.normalize_relpose:
                self.tensor_mean, self.tensor_std, self.relpose_mean, self.relpose_std = result
            else:
                self.tensor_mean, self.tensor_std = result
                self.rel_mean, self.rel_std = None, None

            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)
            if self.normalize_relpose:
                with open(mean_std_relpose_path, 'wb') as f:
                    pickle.dump((self.relpose_mean.detach().cpu(), self.relpose_std.detach().cpu()), f)

            if self.use_interaction_model:
                with open(mean_std_interaction_path, 'wb') as f:
                    pickle.dump((self.rel_mean.detach().cpu(), self.rel_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        if self.load_text_embedding:
            self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
            self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
            clip_suffix = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
            self.embedding_path = {}
            for key_type in self.text_key_list:
                if text_sep:
                    self.embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict_textsep_sepmode{self.sep_mode}{clip_suffix}.pkl')
                else:
                    self.embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{clip_suffix}.pkl')
            self.text_embedding_dict = {}
            if text_sep:
                self.text_mask_dict = {}

            for key_type in self.text_key_list:
                embedding_file = self.embedding_path[key_type]
                # Segment masks live next to the embeddings under a parallel file name.
                mask_file = Path(str(embedding_file).replace('text_embedding_dict', 'text_mask_dict'))
                if embedding_file.exists():
                    print(f"Loading text_{key_type} embeddings from {embedding_file}!")
                    with open(embedding_file, 'rb') as f:
                        self.text_embedding_dict[key_type] = pickle.load(f)
                    if text_sep:
                        with open(mask_file, 'rb') as f:
                            self.text_mask_dict[key_type] = pickle.load(f)
                else:
                    if not hasattr(self, 'dataset'):
                        raise RuntimeError(
                            f'text embeddings for {key_type!r} are not cached at {embedding_file} '
                            'and cannot be computed with load_data=False')
                    print('Calculating text embeddings')
                    raw_texts = []
                    for data in self.dataset:
                        if f'frame_labels_{key_type}' in data:
                            raw_texts.extend([seg['proc_label'] for seg in data['frame_labels_' + key_type]])

                    raw_texts = list(set(raw_texts))
                    num_texts = len(raw_texts)
                    print(f'num of unique texts_{key_type}: ', len(raw_texts))

                    # get text embeddings in batches of 256 texts
                    text_embeddings = []
                    text_mask = []
                    for batch_start_idx in range(0, num_texts, 256):
                        encoded = self.encode_text(raw_texts[batch_start_idx:batch_start_idx + 256],
                                                   text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
                        if text_sep:
                            text_embeddings.append(encoded[0])
                            text_mask.append(encoded[1])
                        else:
                            text_embeddings.append(encoded)

                    self.text_embedding_dict[key_type] = {}
                    if num_texts > 0:  # torch.cat rejects an empty list
                        text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()
                        self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                    # Empty text maps to a zero embedding, compatible with MDM-style text masking.
                    if text_sep:
                        self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
                    else:
                        self.text_embedding_dict[key_type][''] = np.zeros(self.dim_embed_text).astype(np.float32)
                    with open(embedding_file, 'wb') as f:
                        pickle.dump(self.text_embedding_dict[key_type], f)
                    if text_sep:
                        self.text_mask_dict[key_type] = {}
                        if num_texts > 0:
                            text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                            self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                        self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)
                        with open(mask_file, 'wb') as f:
                            pickle.dump(self.text_mask_dict[key_type], f)

                # Keep the lookup tables as tensors on the dataset device.
                for key in self.text_embedding_dict[key_type]:
                    self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device=self.device)
                    if text_sep:
                        self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device=self.device)

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load CLIP ``clip_version`` onto ``device`` into ``self.clip_model`` and freeze it.

        The model is switched to eval mode and none of its parameters require gradients;
        the preprocessing transform returned by ``clip.load`` is discarded.
        """
        # jit=False returns a regular nn.Module (required for training-style use).
        model, _preprocess = clip.load(clip_version, device=device, jit=False)
        self.clip_model = model
        # Cast applicable weights to fp16 with CLIP's own helper.
        clip.model.convert_weights(model)

        # Inference-only: no dropout and no gradients through CLIP.
        model.eval()
        model.requires_grad_(False)
    
    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs = 20, sep_mode=0):
        """Encode a batch of texts with the frozen CLIP text encoder.

        Args:
            raw_text: list of strings (the batch).
            force_empty_zero: zero the embedding of empty strings (or, with ``text_sep``,
                of empty segments).
            text_sep: if True, split each text into at most ``max_segs`` segments and
                encode every segment separately.
            max_segs: number of segment slots per text when ``text_sep``.
            sep_mode: segment delimiters when ``text_sep``: 0 splits on ',' and '.',
                1 additionally splits on the whole words 'and' / 'while'.

        Returns:
            Without ``text_sep``: float tensor ``[B, embed_dim]``.
            With ``text_sep``: ``(embedding [B, max_segs, embed_dim], mask [B, max_segs])``
            where ``mask`` is True for empty (padding) segments.

        Raises:
            ValueError: if ``text_sep`` is set and ``sep_mode`` is not 0 or 1.
        """
        import pandas as pd
        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, embed_dim]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding

        # Regex delimiters for pandas.Series.str.split, keyed by sep_mode.
        split_patterns = {
            0: r'[,.]',
            1: r'\band\b|\bwhile\b|,|\.',
        }
        if sep_mode not in split_patterns:
            raise ValueError(f'unsupported sep_mode: {sep_mode!r} (expected 0 or 1)')

        # Trim whitespace and trailing periods, then split into at most max_segs pieces.
        raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')
        split_df = raw_series.str.split(split_patterns[sep_mode], n=max_segs - 1, expand=True)
        # Strip every segment column-wise (DataFrame.applymap is deprecated in pandas >= 2.1).
        split_df = split_df.fillna('').astype(str).apply(lambda col: col.str.strip())
        # Always expose exactly max_segs slots; missing slots are empty strings.
        split_df = split_df.reindex(columns=range(max_segs), fill_value='')

        segs_matrix = split_df.values  # [B, max_segs]
        segs_flat = segs_matrix.reshape(-1).tolist()

        # True marks an empty (padding) segment.
        text_mask = (segs_matrix == '').astype(bool)
        text_mask = torch.tensor(text_mask, dtype=torch.bool, device=device)

        # Same no-grad policy as the non-separated path above.
        with torch.no_grad():
            tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
            text_embedding = self.clip_model.encode_text(tokenized).float()  # [B*max_segs, embed_dim]
        text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, embed_dim]

        if force_empty_zero:
            text_embedding[text_mask] = 0

        return text_embedding, text_mask
    
    def get_batch_idx(self, batch_size=8):
        if self.split == 'test':
            start_idx = self._curr_test_index
            end_idx = start_idx + batch_size
            batch_idx = np.arange(start_idx, min(end_idx, len(self.dataset)))
            self._curr_test_index = end_idx if end_idx < len(self.dataset) else 0
            return batch_idx
        else:
            batch_idx = np.random.choice(len(self.dataset), size=batch_size, replace=True, p=self.seq_weights)
            return batch_idx

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        new_text_embeddings = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
        for idx, text in enumerate(new_texts):
            if text_sep:
                self.text_embedding_dict[key_type][text] = new_text_embeddings[0][idx]
                self.text_mask_dict[key_type][text] = new_text_embeddings[1][idx]
            else:
                self.text_embedding_dict[key_type][text] = new_text_embeddings[idx]

    def calc_mean_std(self, batch_size=512):
        """Compute per-dimension feature statistics over every primitive of ``self.dataset``.

        Each sequence is cut into overlapping primitives (stride ``future_length``). Each
        person's primitive is canonicalized with ``primitive_utility.canonicalize``,
        converted with ``primitive_utility.calc_features`` and pooled over primitives and
        frames. When relative statistics are needed (``use_interaction_model`` or
        ``normalize_relpose``), three relative features are pooled as well, in both
        directions 'b2a' and 'a2b':
        the relative orientation as 6D (6), the relative root translation (3) and the
        per-joint minimum inter-person distance (22). Only frames after the history window
        are kept.

        Args:
            batch_size: primitives canonicalized at once (capped at 64 when
                ``future_length == 1``).

        Returns:
            ``(tensor_mean, tensor_std)`` when no relative statistics are needed.
            With ``use_interaction_model``, additionally ``(rel_mean, rel_std)``.
            Otherwise, with ``normalize_relpose``, additionally ``(relpose_mean, relpose_std)``.
            All are on ``self.device`` with shape ``[1, 1, D]``.
            When ``normalize_relpose`` is set, ``self.relpose_mean`` / ``self.relpose_std``
            are also assigned. Callers that unpack the interaction statistics therefore
            still get the relative-pose statistics.

        Raises:
            ValueError: on an unsupported ``feature_dim``, or when relative statistics are
                needed but no relative features were collected.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        # Relative features are needed for the interaction statistics and the relpose slice.
        need_rel = self.use_interaction_model or self.normalize_relpose
        motion_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']
        all_mp_data, all_rel_info = [], []
        for seq_data in self.dataset:
            num_frames = seq_data['motion_p1']['body_pose'].shape[0]
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]
            if not primitive_data_list:
                continue  # too short for a single primitive

            primitive_dict = {}
            for person in ['person1', 'person2']:
                person_dict = {
                    # NOTE(review): this is a one-element *set*, not the gender string itself;
                    # kept as-is because the downstream use of 'gender' is not visible here —
                    # confirm whether a plain string was intended.
                    'gender': {primitive_data_list[0]['primitive_dict'][person]['gender']},
                }
                for key in motion_keys:
                    person_dict[key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                if self.padding:
                    person_dict['primitive_padding_mask'] = torch.cat([data['primitive_dict'][person]['primitive_padding_mask'] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(person_dict, self.device)

            # process the stacked primitives in batches
            num_primitives = len(primitive_dict['person1']['body_pose'])
            for batch_start_idx in range(0, num_primitives, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)
                batch_primitive_dict, canonicalized_primitive_dict, transf_rotmat = {}, {}, {}
                valid_indices = None
                for person in ['person1', 'person2']:
                    batch_primitive_dict[person] = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in motion_keys}
                    batch_primitive_dict[person]['gender'] = primitive_dict[person]['gender']
                    # deepcopy keeps batch_primitive_dict in the original (world) frame.
                    transf_rotmat[person], _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(
                        copy.deepcopy(batch_primitive_dict[person]), use_predicted_joints=True)

                for person in ['person1', 'person2']:
                    features = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    if self.primitive_utility.feature_dim == 276:
                        # Drop the trailing frame of the absolute features for this layout.
                        for key in ['transl', 'poses_6d', 'joints']:
                            features[key] = features[key][:, :-1, :]
                    motion_tensor = self.dict_to_tensor(features).detach().cpu()  # [num_primitive, T, D]
                    if self.padding:
                        # Keep only primitives whose future frames contain no padding.
                        mask_slice = primitive_dict[person]['primitive_padding_mask'][batch_start_idx:batch_end_idx, -1]  # [B]
                        valid_indices = torch.nonzero(~mask_slice, as_tuple=True)[0].detach().cpu()
                        motion_tensor = motion_tensor[valid_indices]
                    all_mp_data.append(motion_tensor)

                if need_rel:
                    p1, p2 = batch_primitive_dict['person1'], batch_primitive_dict['person2']
                    B, T, *_ = p1['joints'].shape
                    joints_p1 = p1['joints'].reshape(B, T, 22, 3)  # [B, T+1, 22, 3]
                    joints_p2 = p2['joints'].reshape(B, T, 22, 3)
                    rel_global_orient, rel_root_transl = {}, {}
                    if self.primitive_utility.feature_dim == 276:
                        rel_global_orient['b2a'] = p1['global_orient'].transpose(-1, -2) @ p2['global_orient']
                        rel_global_orient['a2b'] = p2['global_orient'].transpose(-1, -2) @ p1['global_orient']
                        rel_root_transl['b2a'] = torch.matmul(p2['transl'] - p1['transl'], transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul(p1['transl'] - p2['transl'], transf_rotmat['person2'])
                    elif self.primitive_utility.feature_dim == 262:
                        rel_global_orient['b2a'], rel_global_orient['a2b'] = cal_rel_rot(joints_p1, joints_p2)
                        rel_root_transl['b2a'] = torch.matmul(joints_p2[:, :, 0] - joints_p1[:, :, 0], transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul(joints_p1[:, :, 0] - joints_p2[:, :, 0], transf_rotmat['person2'])
                    else:
                        raise ValueError(f'unsupported feature_dim for relative features: {self.primitive_utility.feature_dim}')

                    dists = torch.norm(joints_p1.unsqueeze(3) - joints_p2.unsqueeze(2), dim=-1)  # [B, T+1, 22, 22]
                    rel_mindis = {
                        'b2a': dists.min(dim=-1).values,  # [B, T+1, 22]
                        'a2b': dists.min(dim=-2).values,
                    }

                    for key in ['b2a', 'a2b']:
                        rot_6d = transforms.matrix_to_rotation_6d(rel_global_orient[key])  # [B, T+1, 6]
                        rel_info = torch.cat([rot_6d, rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B, T+1, 6+3+22]
                        rel_info = rel_info[:, self.history_length:-1, :].detach().cpu()
                        if self.padding:
                            rel_info = rel_info[valid_indices]
                        all_rel_info.append(rel_info)

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        if not need_rel:
            return tensor_mean.to(self.device), tensor_std.to(self.device)

        if not all_rel_info:
            raise ValueError('no relative features were collected; cannot compute interaction/relpose statistics')
        all_rel_info = torch.cat(all_rel_info, dim=0)
        rel_mean = all_rel_info.mean(dim=[0, 1], keepdim=True).to(self.device)
        rel_std = all_rel_info.std(dim=[0, 1], keepdim=True).to(self.device)
        # The first 9 relative dims are the 6D orientation followed by the 3D root translation.
        relpose_mean, relpose_std = rel_mean[..., :9], rel_std[..., :9]
        if self.normalize_relpose:
            self.relpose_mean, self.relpose_std = relpose_mean, relpose_std
        # Interaction statistics take precedence, matching how the constructor unpacks them.
        if self.use_interaction_model:
            return tensor_mean.to(self.device), tensor_std.to(self.device), rel_mean, rel_std
        return tensor_mean.to(self.device), tensor_std.to(self.device), relpose_mean, relpose_std

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """end_frame included"""
        primitive_dict = {}
        for person, motion_data in zip(['person1', 'person2'], [seq_data['motion_p1'], seq_data['motion_p2']]):
            primitive_dict[person] = {
                'gender': motion_data['gender'],
                'betas': motion_data['betas'].expand(1, self.primitive_length + 1, 10),
                'transl': motion_data['transl'][start_frame:end_frame + 1].unsqueeze(0),
                'global_orient': motion_data['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                'body_pose': motion_data['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'pelvis_delta': motion_data['pelvis_delta'].unsqueeze(0),
                'joints': motion_data['joints'][start_frame:end_frame + 1].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }
            if self.padding:
                padding_mask_full = seq_data[f'motion_p{person[-1]}']['padding_mask'][start_frame:end_frame + 1]  # shape [T+1]
                history_mask = torch.tensor(padding_mask_full[:self.history_length], dtype=torch.bool)
                future_mask = padding_mask_full[self.history_length:-1]
                future_flag = torch.tensor(future_mask.any(), dtype=torch.bool)
                primitive_dict[person]['primitive_padding_mask'] = torch.cat([history_mask, future_flag.unsqueeze(0)], dim=0).unsqueeze(0) # (1, history_length + 1)
        
        texts = {key: [] for key in self.text_key_list}
        for key_type in self.text_key_list:
            if not skip_text and f'frame_labels_{key_type}' in seq_data:
                future_start = (start_frame + self.history_length) / self.target_fps
                future_end = (start_frame + self.history_length + self.future_length - 1) / self.target_fps
                # print('text tolerance: ', self.text_tolerance)
                for seg in seq_data[f'frame_labels_{key_type}']:
                    if have_overlap([seg['start_t'], seg['end_t']], [future_start - self.text_tolerance, future_end + self.text_tolerance]):
                        texts[key_type].append(seg['proc_label'])

        output = {}
        for key_type in self.text_key_list:
            output['text_'+key_type] = random.choice(texts[key_type]) if len(texts[key_type]) > 0 else ''
        output['primitive_dict'] = primitive_dict
        return output

    def get_relpose_mean_std_by_device(self, device):
        if not hasattr(self, 'relpose_mean_device_dict'):
            self.relpose_mean_device_dict = {}

        if device not in self.relpose_mean_device_dict:
            assert self.relpose_mean is not None and self.relpose_std is not None, "rel_mean/std must be computed before normalization."
            self.relpose_mean_device_dict[device] = (
                self.relpose_mean.to(device=device),
                self.relpose_std.to(device=device)
            )
        return self.relpose_mean_device_dict[device]

    def get_rel_mean_std_by_device(self, device):
        if not hasattr(self, 'rel_mean_device_dict'):
            self.rel_mean_device_dict = {}

        if device not in self.rel_mean_device_dict:
            assert self.rel_mean is not None and self.rel_std is not None, "rel_mean/std must be computed before normalization."
            self.rel_mean_device_dict[device] = (
                self.rel_mean.to(device=device),
                self.rel_std.to(device=device)
            )
        return self.rel_mean_device_dict[device]

    def normalize_rel_pose(self, rel_pose: torch.Tensor) -> torch.Tensor:
        """
        Standardize interaction feature tensor using rel_mean / rel_std
        rel_pose: Tensor of shape [B, D] or [B, T, D]
        """
        relpose_mean, relpose_std = self.get_relpose_mean_std_by_device(rel_pose.device)
        relpose_std_safe = relpose_std.clone()
        relpose_std_safe[relpose_std_safe == 0] = 1.0  # avoid division by zero
        return (rel_pose - relpose_mean) / relpose_std_safe

    def normalize_rel_info(self, rel_info: torch.Tensor) -> torch.Tensor:
        """
        Standardize interaction feature tensor using rel_mean / rel_std
        rel_info: Tensor of shape [B, D] or [B, T, D]
        """
        rel_mean, rel_std = self.get_rel_mean_std_by_device(rel_info.device)
        rel_std_safe = rel_std.clone()
        rel_std_safe[rel_std_safe == 0] = 1.0  # avoid division by zero
        return (rel_info - rel_mean) / rel_std_safe

    def get_batch(self, batch_size=8):
        """Sample a batch of consecutive motion primitives from two-person sequences.

        For every index returned by ``self.get_batch_idx(batch_size)``:
        - the two persons are swapped with probability 0.5;
        - a window of ``self.seq_length`` frames is selected. The start frame is
          drawn with ``frame_weights`` when ``'text'`` is in
          ``self.weight_scheme``, and uniformly otherwise;
        - the window is cut into primitives with ``self.get_primitive``.

        Sequences are grouped by person1's gender ('female', 'male', 'neutral').
        Each group is canonicalized and featurized on ``self.device``. The groups
        are then concatenated along dim 0.

        Returns:
            ``None`` if no sampled sequence has one of the three genders.
            mode == 'merged': dict with keys 'person1', 'person2' and
                'interaction'. Each value is a list with one dict per primitive,
                holding:
                - motion tensors;
                - texts / text embeddings;
                - for 'interaction': relative poses and, with
                  ``use_interaction_model``, normalized relative features.
            mode == 'sep': list with one dict per primitive whose rows are drawn
                at random from the person1 and person2 motions.
        """
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        # Per-primitive keys holding python lists (joined with `+`) and keys
        # holding tensors (joined with torch.cat along dim 0).
        add_key_list = ['gender']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'transf_rotmat', 'transf_transl']
        if self.padding:
            cat_key_list.append('primitive_padding_mask')
        # 'texts' must appear only once. A duplicate entry made the cross-gender
        # concatenation append the later gender's texts twice.
        if self.use_indi_text or self.load_text_embedding:
            add_key_list.append('texts')
        if self.load_text_embedding:
            cat_key_list.append('text_embedding')
            if self.text_sep:
                cat_key_list.append('text_mask')
        body_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']

        def _merge_into(dst, src, list_keys, tensor_keys):
            """Append ``src``'s entries to ``dst`` in place (list `+` / torch.cat on dim 0).

            Keys absent from either dict are skipped, as are tensor entries that are
            ``None`` (e.g. ``text_mask`` without ``text_sep``). Such keys are not
            produced under the current configuration.
            """
            for key in list_keys:
                if key in dst and key in src:
                    dst[key] = dst[key] + src[key]
            for key in tensor_keys:
                if dst.get(key) is not None and src.get(key) is not None:
                    dst[key] = torch.cat([dst[key], src[key]], dim=0)

        for seq_idx in batch_idx:
            # Shallow copy so that the person swap below does not mutate the dataset.
            seq_data = dict(self.dataset[seq_idx])
            # exchange person1 and person2
            if random.random() < 0.5:
                seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
                if self.use_indi_text:
                    seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = seq_data['frame_labels_person2'], seq_data['frame_labels_person1']
            num_frames = len(seq_data['motion_p1']['transl'])
            if 'text' in self.weight_scheme:
                start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
            else:
                start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
            primitive_data_list = []
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # sort batch by gender
        batch = None
        for gender in ['female', 'male', 'neutral']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = {} if self.mode == 'merged' else []

            # Stack all primitives of this gender along dim 0, primitive-major:
            # rows [p * gender_batch_size, (p + 1) * gender_batch_size) belong to primitive p.
            gender_seq_texts = {key_type: None for key_type in self.text_key_list}
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in body_keys:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                    if self.padding:
                        primitive_dict[person]['primitive_padding_mask'] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person]['primitive_padding_mask'] for mp_seq in gender_seq_list], dim=0)
                primitive_texts = {}
                for key_type in self.text_key_list:
                    primitive_texts[key_type] = [mp_seq[primitive_idx]['text_'+key_type] for mp_seq in gender_seq_list]
                    gender_seq_texts[key_type] = primitive_texts[key_type] if gender_seq_texts[key_type] is None else gender_seq_texts[key_type] + primitive_texts[key_type]

                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in body_keys:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                        if self.padding:
                            gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'],
                                                                                           primitive_dict[person]['primitive_padding_mask']], dim=0)

            canonicalized_primitive_dict = {}
            transf_rotmat, transf_transl = {}, {}
            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)
            if self.mode == 'merged':
                # rel_pose
                rel_rotmat, rel_transl, rel_pose = {}, {}, {}
                rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
                rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
                for rel in ['b2a', 'a2b']:
                    rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                    rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]

                if self.use_interaction_model:
                    B, T, *_ = gender_seq_dict['person1']['joints'].shape
                    # relative rotation, relative root translation, per-joint minimum distance
                    rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                    if self.primitive_utility.feature_dim == 276:
                        rel_global_orient['b2a'] = gender_seq_dict['person1']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person2']['global_orient']
                        rel_global_orient['a2b'] = gender_seq_dict['person2']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person1']['global_orient']
                        rel_root_transl['b2a'] = torch.matmul((gender_seq_dict['person2']['transl']-gender_seq_dict['person1']['transl']), transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul((gender_seq_dict['person1']['transl']-gender_seq_dict['person2']['transl']), transf_rotmat['person2'])
                    elif self.primitive_utility.feature_dim == 262:
                        rel_global_orient['b2a'], rel_global_orient['a2b'] = cal_rel_rot(gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3), gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3))
                        rel_root_transl['b2a'] = torch.matmul(gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3)[:,:,0]-gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3)[:,:,0], transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul(gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3)[:,:,0]-gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3)[:,:,0], transf_rotmat['person2'])
                    # Pairwise joint distances [B, T, 22 (person1), 22 (person2)].
                    dists = torch.norm(gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3).unsqueeze(3)-gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3).unsqueeze(2), dim=-1)
                    rel_mindis['b2a'], _ = dists.min(dim=-1)
                    rel_mindis['a2b'], _ = dists.min(dim=-2)

                    rel_info = {}
                    for key in ['b2a', 'a2b']:
                        rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B*num_mp, T, 6+3+22]

            # calc features
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                if self.primitive_utility.feature_dim == 276:
                    feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                    feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                    feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]

            if self.mode == 'merged':
                for person in ['person1', 'person2']:
                    gender_batch[person] = []
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        if self.use_indi_text:
                            primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                                text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None

                        gender_batch[person].append({
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            })
                        if self.use_indi_text:
                            gender_batch[person][-1]['texts'] = primitive_texts
                            if self.load_text_embedding:
                                gender_batch[person][-1]['text_embedding'] = text_embedding
                                gender_batch[person][-1]['text_mask'] = text_mask
                        if self.padding:
                            gender_batch[person][-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                gender_batch['interaction'] = []
                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                    if self.load_text_embedding:
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                        text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None
                    gender_batch['interaction'].append({
                            'texts': primitive_texts,
                            'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                            'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                        })
                    if self.load_text_embedding:
                        gender_batch['interaction'][-1]['text_embedding'] = text_embedding
                        gender_batch['interaction'][-1]['text_mask'] = text_mask
                    if self.normalize_relpose:
                        gender_batch['interaction'][-1]['rel_pose_b2a'] = self.normalize_rel_pose(gender_batch['interaction'][-1]['rel_pose_b2a'])
                        gender_batch['interaction'][-1]['rel_pose_a2b'] = self.normalize_rel_pose(gender_batch['interaction'][-1]['rel_pose_a2b'])
                    if self.use_interaction_model:
                        gender_batch['interaction'][-1].update({
                            'rel_info_b2a': self.normalize_rel_info(rel_info['b2a'][start_idx:end_idx, self.cfg.history_length:-1]),
                            'rel_info_a2b': self.normalize_rel_info(rel_info['a2b'][start_idx:end_idx, self.cfg.history_length:-1]),
                        })

            elif self.mode == 'sep':
                # gender_batch is filled person-major: the first num_primitive
                # entries are person1's primitives, the next num_primitive are person2's.
                for person in ['person1', 'person2']:
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)                   # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        if self.use_indi_text:
                            primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                                text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None
                        else:
                            primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                            # Embeddings are only consumed when load_text_embedding is set.
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                                text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None

                        gender_batch.append(
                            {
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            }
                        )
                        if self.use_indi_text:
                            gender_batch[-1]['texts'] = primitive_texts
                        if self.load_text_embedding:
                            gender_batch[-1]['texts'] = primitive_texts
                            gender_batch[-1]['text_embedding'] = text_embedding
                            gender_batch[-1]['text_mask'] = text_mask
                        if self.padding:
                            gender_batch[-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]

                # Random 0/1 mask over the 2 * gender_batch_size candidates: rows whose
                # selector is 1 in the first half come from person1 (front group),
                # rows whose selector is 1 in the second half come from person2 (back group).
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]

                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key], back_group[key] = [], []
                    for d in gender_batch[:self.num_primitive]:
                        front_group[key] += d[key]
                    for d in gender_batch[self.num_primitive:]:
                        back_group[key] += d[key]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                    back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]

                selected_batch = []
                for primitive_idx in range(self.num_primitive):
                    offset = primitive_idx * gender_batch_size
                    front_rows = front_indices + offset
                    back_rows = back_indices + offset
                    selected_dict = {'history_length': self.cfg.history_length,
                                     'future_length': self.cfg.future_length,}
                    for key in front_group.keys():
                        if key in add_key_list:
                            selected_front = [front_group[key][row] for row in front_rows.tolist()]
                            selected_back = [back_group[key][row] for row in back_rows.tolist()]
                            selected_dict[key] = selected_front + selected_back
                        elif key in cat_key_list:
                            selected_front = front_group[key][front_rows]
                            selected_back = back_group[key][back_rows]
                            selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)
                    selected_batch.append(selected_dict)
                gender_batch = selected_batch

            if batch is None:
                batch = gender_batch
            else:  # concatenate this gender's batch onto the previous genders
                interaction_tensor_keys = ['rel_pose_b2a', 'rel_pose_a2b', 'rel_info_b2a', 'rel_info_a2b', 'text_embedding', 'text_mask']
                for primitive_idx in range(self.num_primitive):
                    if self.mode == 'merged':
                        for key_type in self.key_list:
                            if key_type != 'interaction':
                                _merge_into(batch[key_type][primitive_idx], gender_batch[key_type][primitive_idx],
                                            add_key_list, cat_key_list)
                            else:
                                _merge_into(batch[key_type][primitive_idx], gender_batch[key_type][primitive_idx],
                                            ['texts'], interaction_tensor_keys)
                    else:
                        _merge_into(batch[primitive_idx], gender_batch[primitive_idx], add_key_list, cat_key_list)

        return batch
    
    def get_item(self, idx):
        seq_data = self.dataset[idx]
        num_frames = len(seq_data['motion_p1']['transl'])
        gender = {}
        gender['person1'] = seq_data['motion_p1']['gender']
        gender['person2'] = seq_data['motion_p2']['gender']
        if 'text' in self.weight_scheme:
            start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
        else:
            start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
        primitive_data_list = []
        for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
            primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
            primitive_data_list.append(primitive_data)
        
        gender_seq_texts = {key_type: [] for key_type in self.text_key_list}
        gender_seq_dict = None
        for primitive_idx in range(self.num_primitive):
            primitive_dict = {}
            for person in ['person1', 'person2']:
                primitive_dict[person] = {}
                primitive_dict[person]['gender'] = gender[person]
                for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                    primitive_dict[person][key] = primitive_data_list[primitive_idx]['primitive_dict'][person][key]
                if self.padding:
                    primitive_dict[person]['primitive_padding_mask'] = primitive_data_list[primitive_idx]['primitive_dict'][person]['primitive_padding_mask']
            primitive_texts = {}
            for key_type in self.text_key_list:
                primitive_texts[key_type] = primitive_data_list[primitive_idx]['text_'+key_type]
                gender_seq_texts[key_type].append(primitive_texts[key_type])
            
            if gender_seq_dict is None:
                gender_seq_dict = primitive_dict
            else:
                for person in ['person1', 'person2']:
                    for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                        gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                    if self.padding:
                        gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'],
                                                                                           primitive_dict[person]['primitive_padding_mask']], dim=0)

        canonicalized_primitive_dict = {}
        if self.mode == 'merged':
            transf_rotmat, transf_transl = {}, {}
        for person in ['person1', 'person2']:
            gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            if self.mode == 'merged':
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)            
            else:
                _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)          
        
        if self.mode == 'merged':
            # rel_pose
            rel_rotmat, rel_transl, rel_pose = {}, {}, {}
            rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
            for rel in ['b2a', 'a2b']:
                rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]
            
            if self.use_interaction_model:
                B, T, *_ = gender_seq_dict['person1']['joints'].shape
                # reltive transition, relative distance
                rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                if self.primitive_utility.feature_dim == 276:
                    rel_global_orient['b2a'] = gender_seq_dict['person1']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person2']['global_orient']
                    rel_global_orient['a2b'] = gender_seq_dict['person2']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person1']['global_orient']
                    # rel_root_transl['b2a'] = torch.matmul(gender_seq_dict['person1']['global_orient'].transpose(-1, -2), (gender_seq_dict['person2']['transl']-gender_seq_dict['person1']['transl']).unsqueeze(-1)).squeeze(-1)
                    # rel_root_transl['a2b'] = torch.matmul(gender_seq_dict['person2']['global_orient'].transpose(-1, -2), (gender_seq_dict['person1']['transl']-gender_seq_dict['person2']['transl']).unsqueeze(-1)).squeeze(-1)
                    rel_root_transl['b2a'] = torch.matmul((gender_seq_dict['person2']['transl']-gender_seq_dict['person1']['transl']), transf_rotmat['person1'])
                    rel_root_transl['a2b'] = torch.matmul((gender_seq_dict['person1']['transl']-gender_seq_dict['person2']['transl']), transf_rotmat['person2'])
                elif self.primitive_utility.feature_dim == 262:
                    rel_global_orient['b2a'], rel_global_orient['a2b'] = cal_rel_rot(gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3), gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3))
                    rel_root_transl['b2a'] = torch.matmul(gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3)[:,:,0]-gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3)[:,:,0], transf_rotmat['person1'])
                    rel_root_transl['a2b'] = torch.matmul(gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3)[:,:,0]-gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3)[:,:,0], transf_rotmat['person2']) 
                dists = torch.norm(gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3).unsqueeze(3)-gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3).unsqueeze(2), dim=-1)
                rel_mindis['b2a'], _ = dists.min(dim=-1)
                rel_mindis['a2b'], _ = dists.min(dim=-2)
                
                rel_info = {}
                for key in ['b2a', 'a2b']:
                    rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B*num_mp, T, 6+3+22]

        # calc features
        feature_dict = {}
        for person in ['person1', 'person2']:
            feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
            if self.primitive_utility.feature_dim == 276:
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]

        
        data_batch = {} if self.mode == 'merged' else []
        if self.mode == 'merged':
            for person in ['person1', 'person2']:
                data_batch[person] = []
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx 
                    end_idx = primitive_idx + 1
                    if self.use_indi_text:
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        if self.load_text_embedding:
                            unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                            if len(unseen_texts) > 0:
                                self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                            text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                            if self.text_sep:
                                text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                            else:
                                text_mask = None
            
                    data_batch[person].append({
                            'gender': [gender_seq_dict[person]['gender']],
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                            'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        })
                    if self.use_indi_text:
                        data_batch[person][-1]['texts'] = primitive_texts
                        if self.load_text_embedding:
                            data_batch[person][-1]['text_embedding'] = text_embedding
                            data_batch[person][-1]['text_mask'] = text_mask
                    if self.padding:
                        data_batch[person][-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
            data_batch['interaction'] = []
            for primitive_idx in range(self.num_primitive):        
                start_idx = primitive_idx
                end_idx = (primitive_idx + 1)
                primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                if self.load_text_embedding:
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                    else:
                        text_mask = None
                data_batch['interaction'].append({
                        'texts': primitive_texts,
                        'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                        'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                    })
                if self.load_text_embedding:
                    data_batch['interaction'][-1]['text_embedding'] = text_embedding
                    data_batch['interaction'][-1]['text_mask'] = text_mask
                if self.use_interaction_model:
                    data_batch['interaction'][-1].update({
                        'rel_info_b2a': rel_info['b2a'][start_idx:end_idx, self.cfg.history_length:-1].reshape(1, -1),
                        'rel_info_a2b': rel_info['a2b'][start_idx:end_idx, self.cfg.history_length:-1].reshape(1, -1),
                    })
            return data_batch
        else:
            for person in ['person1', 'person2']:
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [num_mp, D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx
                    end_idx = primitive_idx + 1
                    if self.use_indi_text:
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        if self.load_text_embedding:
                            unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                            if len(unseen_texts) > 0:
                                new_text_embeddings = encode_text(self.clip_model, unseen_texts)
                                for idx, text in enumerate(unseen_texts):
                                    self.text_embedding_dict[person][text] = new_text_embeddings[idx]
                            text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                            if self.text_sep:
                                text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                            else:
                                text_mask = None
                    data_batch.append(
                        {
                            'gender': [gender_seq_dict[person]['gender']],
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [1, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        }
                    )
                    if self.use_indi_text:
                        data_batch[-1]['texts'] = primitive_texts
                        if self.load_text_embedding:
                            data_batch[-1]['text_embedding'] = text_embedding
                            data_batch[-1]['text_mask'] = text_mask
                            
                    if self.padding:
                        data_batch[-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
            if random.random() < 0.5:
                return data_batch[:self.num_primitive]
            else:
                return data_batch[self.num_primitive:]


class InterGenDatasetWPE(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_single_interaction_d262_fps30_mirror_exchangeyz',
                 cfg_path='./config_files/config_hydra/motion_primitive/interhuman_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='merged',  # 'sep' or 'merged'
                 text_sep=False,
                 max_segs=20,
                 motion_repr={'joints': 22 * 3,
                              'joints_delta': 22 * 3,
                              'body_pose': 21 * 6,
                              'feet_contact': 4},
                 **kwargs):
        """Two-person interaction dataset that serves fixed-length motion primitives.

        When ``load_data`` is True, reads ``{dataset_path}/{split}.pkl``, drops
        too-short and blacklisted sequences, optionally pads short sequences by
        repeating their last frame, converts each person's motion to float32
        tensors and assigns per-sequence sampling weights. It then loads (or,
        for the train split, computes and caches) normalization statistics and,
        if ``load_text_embedding`` is set, CLIP text embeddings for every key in
        ``self.text_key_list``.

        Args:
            dataset_name: stored as ``self.dataset_name``.
            dataset_path: directory holding ``{split}.pkl`` and the cached
                statistics / text-embedding pickles.
            cfg_path: OmegaConf YAML providing ``fps``, ``history_length``,
                ``future_length`` and ``num_primitive``.
            split: split name; selects ``{split}.pkl`` and the embedding caches.
            device: torch device for PrimitiveUtility, CLIP and cached embeddings.
            weight_scheme: must contain 'uniform' (equal sampling weights) or
                'length' (weight = number of non-padded frames).
            prob_static: stored on the instance.
            enforce_gender: if not None, overrides both persons' gender.
            enforce_zero_beta: if truthy, replaces betas with zeros.
            load_data: if False, no sequences are loaded and ``self.dataset`` /
                ``self.seq_weights`` are not set.
            text_tolerance: stored on the instance.
            body_type: forwarded to PrimitiveUtility.
            seed_only: requires ``cfg.num_primitive == 1``.
            use_frame_weights: accepted for interface compatibility; unused here.
            mode: 'merged' adds 'interaction' to ``self.key_list``; 'sep' keeps
                only the two persons (and always uses per-person texts).
            text_sep: encode texts as per-clause segment embeddings plus masks.
            max_segs: number of clause slots per text when ``text_sep`` is True.
            motion_repr: feature-name -> dimension mapping; a deep copy is passed
                to PrimitiveUtility so the shared default is never mutated.
            **kwargs: optional ``sep_mode`` (0), ``padding`` (False),
                ``normalize_relpose`` (False), ``use_interaction_model`` (False),
                ``clip_version`` ('ViT-B/32'), ``load_text_embedding`` (False),
                ``use_indi_text`` (False).

        Raises:
            ValueError: if sequences are loaded and ``weight_scheme`` contains
                neither 'uniform' nor 'length'.
        """
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.sep_mode = kwargs.get('sep_mode', 0)
        self.padding = kwargs.get('padding', False)
        self.normalize_relpose = kwargs.get('normalize_relpose', False)
        self.use_interaction_model = kwargs.get('use_interaction_model', False)
        self.key_list = ['person1', 'person2', 'interaction'] if self.mode == 'merged' else ['person1', 'person2']
        self.feet_thre = 0.001
        self.n_joints = 22

        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', False)
        self.use_indi_text = kwargs.get('use_indi_text', False)
        self.text_key_list = ['person1', 'person2', 'interaction'] if (self.use_indi_text or self.mode == 'sep') else ['interaction']

        # Deep-copy so neither the shared default dict nor the caller's dict can
        # be mutated through PrimitiveUtility / self.motion_repr.
        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type,
                                                  motion_repr=copy.deepcopy(motion_repr))
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # Both lengths include one frame beyond the primitives (presumably for
        # next-frame/delta features that are later trimmed; TODO confirm).
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1
        self.min_length = self.history_length + self.future_length + 1

        if load_data:
            # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            if not self.padding:
                dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            # Sequences excluded by name (reason not recorded here).
            elements_to_remove = ['7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385']
            dataset = [data for data in dataset
                       if data['seq_name'] not in elements_to_remove
                       and data['motion_p1']['trans'].shape[0] >= self.min_length]

            def convert_motion(motion, gender, enforce_zero_beta):
                """Convert one person's numpy motion dict into float32 torch tensors."""
                betas = torch.from_numpy(motion['betas'].astype(np.float32))
                if enforce_zero_beta:
                    betas = torch.zeros_like(betas)
                transl = torch.from_numpy(motion['trans'].astype(np.float32))
                global_orient = transforms.axis_angle_to_matrix(torch.from_numpy(motion['global_orient'].astype(np.float32)))
                body_pose = torch.from_numpy(motion['pose_body'].astype(np.float32)).reshape(-1, 21, 6)  # [T, 21, 6]
                pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))  # [3]
                joints = torch.from_numpy(motion['joints'].astype(np.float32))  # [T, 22, 3]
                result = {
                    'gender': gender,
                    'betas': betas,
                    'transl': transl,
                    'global_orient': global_orient,
                    'body_pose': body_pose,
                    'pelvis_delta': pelvis_delta,
                    'joints': joints,
                }
                if self.padding:
                    result['padding_mask'] = motion['padding_mask']
                return result

            for data in dataset:
                if self.padding:
                    # Person 1's length is used for both persons.
                    T = data['motion_p1']['trans'].shape[0]
                    if T < self.seq_length:
                        # Pad to seq_length by repeating the last frame and flag
                        # the padded frames with True in padding_mask.
                        pad_len = self.seq_length - T
                        for person in ['motion_p1', 'motion_p2']:
                            for key in ['trans', 'global_orient', 'pose_body', 'joints']:
                                padding = np.repeat(data[person][key][-1:], pad_len, axis=0)
                                data[person][key] = np.concatenate([data[person][key], padding], axis=0)
                            data[person]['padding_mask'] = np.concatenate([
                                np.zeros(T, dtype=np.bool_),
                                np.ones(pad_len, dtype=np.bool_),
                            ], axis=0)
                    else:
                        for person in ['motion_p1', 'motion_p2']:
                            data[person]['padding_mask'] = np.zeros(T, dtype=np.bool_)

                gender_p1 = self.enforce_gender if self.enforce_gender is not None else data['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else data['motion_p2']['gender']
                data['motion_p1'] = convert_motion(data['motion_p1'], gender_p1, self.enforce_zero_beta)
                data['motion_p2'] = convert_motion(data['motion_p2'], gender_p2, self.enforce_zero_beta)

            print('num of sequences: ', len(dataset))

            # Assign a sampling weight to each sequence.
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    num_frames = len(data['motion_p1']['transl'])
                    if self.padding and num_frames == self.seq_length:
                        # Count only real (non-padded) frames.
                        data['weight'] = num_frames - sum(data['motion_p1']['padding_mask'])
                    else:
                        data['weight'] = num_frames
                else:
                    raise ValueError(f"unsupported weight_scheme: {weight_scheme!r} (expected 'uniform' or 'length')")
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights
        self._curr_test_index = 0

        # Load or compute normalization statistics (cached as pickles in dataset_path).
        self.tensor_mean_device_dict = {}
        # Default to None so these attributes exist on both the cached and the
        # freshly computed path when the interaction model is disabled.
        self.rel_mean, self.rel_std = None, None
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'
        suffix = '_padding' if self.padding else ''
        mean_std_path = Path(dataset_path, f'{file_name}{suffix}.pkl')
        mean_std_relpose_path = Path(dataset_path, f'{file_name}_relpose{suffix}.pkl')
        mean_std_interaction_path = Path(dataset_path, f'{file_name}_interaction{suffix}.pkl')

        cache_ready = (mean_std_path.exists()
                       and (not self.use_interaction_model or mean_std_interaction_path.exists())
                       and (not self.normalize_relpose or mean_std_relpose_path.exists()))
        if cache_ready:
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]

            if self.normalize_relpose:
                print(f'loading relpose mean and std from {mean_std_relpose_path}')
                with open(mean_std_relpose_path, 'rb') as f:
                    self.relpose_mean, self.relpose_std = pickle.load(f)

            if self.use_interaction_model:
                print(f'loading interaction mean and std from {mean_std_interaction_path}')
                with open(mean_std_interaction_path, 'rb') as f:
                    self.rel_mean, self.rel_std = pickle.load(f)  # [1, 1, D]
        else:
            assert self.split == 'train'
            print('calculating mean and std using train split')
            result = self.calc_mean_std()

            # NOTE(review): if use_interaction_model and normalize_relpose are both
            # set, relpose_mean is never assigned here and the relpose dump below
            # fails -- confirm what calc_mean_std returns in that configuration.
            if self.use_interaction_model:
                self.tensor_mean, self.tensor_std, self.rel_mean, self.rel_std = result
            elif self.normalize_relpose:
                self.tensor_mean, self.tensor_std, self.relpose_mean, self.relpose_std = result
            else:
                self.tensor_mean, self.tensor_std = result

            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)
            if self.normalize_relpose:
                with open(mean_std_relpose_path, 'wb') as f:
                    pickle.dump((self.relpose_mean.detach().cpu(), self.relpose_std.detach().cpu()), f)
            if self.use_interaction_model:
                with open(mean_std_interaction_path, 'wb') as f:
                    pickle.dump((self.rel_mean.detach().cpu(), self.rel_std.detach().cpu()), f)

        # Load the CLIP model and the (cached or freshly computed) text embeddings.
        if self.load_text_embedding:
            self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
            self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
            clip_suffix = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
            self.embedding_path = {}
            for key_type in self.text_key_list:
                if text_sep:
                    self.embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict_textsep_sepmode{self.sep_mode}{clip_suffix}.pkl')
                else:
                    self.embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{clip_suffix}.pkl')
            self.text_embedding_dict = {}
            if text_sep:
                self.text_mask_dict = {}

            for key_type in self.text_key_list:
                embedding_path = self.embedding_path[key_type]
                mask_path = Path(str(embedding_path).replace('text_embedding_dict', 'text_mask_dict'))
                if embedding_path.exists():
                    print(f"Loading text_{key_type} embeddings from {embedding_path}!")
                    with open(embedding_path, 'rb') as f:
                        self.text_embedding_dict[key_type] = pickle.load(f)
                    if text_sep:
                        with open(mask_path, 'rb') as f:
                            self.text_mask_dict[key_type] = pickle.load(f)
                else:
                    print('Calculating text embeddings')
                    raw_texts = []
                    for data in self.dataset:
                        if f'frame_labels_{key_type}' in data:
                            raw_texts.extend([seg['proc_label'] for seg in data['frame_labels_' + key_type]])

                    raw_texts = list(set(raw_texts))
                    num_texts = len(raw_texts)
                    print(f'num of unique texts_{key_type}: ', len(raw_texts))

                    # Encode in batches of 256 texts.
                    text_embeddings, text_mask = [], []
                    for batch_start_idx in range(0, num_texts, 256):
                        encoded = self.encode_text(raw_texts[batch_start_idx:batch_start_idx + 256],
                                                   text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
                        if text_sep:
                            text_embeddings.append(encoded[0])
                            text_mask.append(encoded[1])
                        else:
                            text_embeddings.append(encoded)
                    # torch.cat rejects an empty list; with no texts the dicts below
                    # only get the '' entry instead of crashing.
                    if num_texts > 0:
                        text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()

                    self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                    # Empty text maps to an all-zero embedding (compatible with mdm text masking).
                    if text_sep:
                        self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
                    else:
                        self.text_embedding_dict[key_type][''] = np.zeros(self.dim_embed_text).astype(np.float32)
                    with open(embedding_path, 'wb') as f:
                        pickle.dump(self.text_embedding_dict[key_type], f)
                    if text_sep:
                        if num_texts > 0:
                            text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                        self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                        # NOTE(review): the '' mask is all False (no segment masked) even though
                        # its embedding is all zeros -- presumably to avoid fully-masked rows; confirm.
                        self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)
                        with open(mask_path, 'wb') as f:
                            pickle.dump(self.text_mask_dict[key_type], f)

                # Convert every cached numpy entry to a tensor on the target device.
                for key in self.text_embedding_dict[key_type]:
                    self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device=self.device)
                    if text_sep:
                        self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device=self.device)

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load CLIP ``clip_version`` onto ``device`` as a frozen text encoder.

        The model is stored in ``self.clip_model``. It is put in eval mode, and
        every parameter has ``requires_grad`` switched off.
        """
        # jit=False is required for use inside training code.
        model, _preprocess = clip.load(clip_version, device=device, jit=False)
        self.clip_model = model
        # Casts eligible layers to fp16. Per the CLIP source, clip.load already
        # returns fp16 weights on CUDA but fp32 on CPU.
        clip.model.convert_weights(model)

        model.eval()
        for param in model.parameters():
            param.requires_grad = False
    
    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs = 20, sep_mode=0):
        import pandas as pd
        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, 512]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding
                
        raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')
        if sep_mode == 0:
            split_df = raw_series.str.split(r'[,.]', n=max_segs - 1, expand=True)
        elif sep_mode == 1:
            split_df = raw_series.str.split(r'\band\b|\bwhile\b|,|\.', n=max_segs - 1, expand=True)
        split_df = split_df.fillna('').astype(str).applymap(str.strip)

        split_df = split_df.reindex(columns=range(max_segs), fill_value='')
        
        segs_matrix = split_df.values
        segs_flat = segs_matrix.reshape(-1).tolist()

        text_mask = (segs_matrix == '').astype(bool)
        text_mask = torch.tensor(text_mask, dtype=torch.bool, device=device)

        tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
        text_embedding = self.clip_model.encode_text(tokenized).float()      # [B*max_segs, 512]
        text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, 512]

        if force_empty_zero:
            text_embedding[text_mask] = 0

        return text_embedding, text_mask
    
    def get_batch_idx(self, batch_size=8):
        if self.split == 'test':
            start_idx = self._curr_test_index
            end_idx = start_idx + batch_size
            batch_idx = np.arange(start_idx, min(end_idx, len(self.dataset)))
            self._curr_test_index = end_idx if end_idx < len(self.dataset) else 0
            return batch_idx
        else:
            batch_idx = np.random.choice(len(self.dataset), size=batch_size, replace=True, p=self.seq_weights)
            return batch_idx

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        """Encode ``new_texts`` with CLIP and add them to the ``key_type`` caches.

        Entries are stored as the tensors returned by ``encode_text``. When
        ``text_sep`` is set, the per-segment mask for each text is also stored
        in ``self.text_mask_dict[key_type]``.
        """
        encoded = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
        if text_sep:
            # encode_text returns (embeddings, masks) in segment mode.
            embeddings, masks = encoded
            for row, text in enumerate(new_texts):
                self.text_embedding_dict[key_type][text] = embeddings[row]
                self.text_mask_dict[key_type][text] = masks[row]
        else:
            for row, text in enumerate(new_texts):
                self.text_embedding_dict[key_type][text] = encoded[row]

    def calc_mean_std(self, batch_size=512):
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        all_mp_data, all_rel_info = [], []
        for seq_data in self.dataset:
            motion_data_p1 = seq_data['motion_p1']
            num_frames = motion_data_p1['body_pose'].shape[0]
            primitive_data_list = []
            for start_frame in range(0, num_frames - self.primitive_length, self.future_length):
                end_frame = start_frame + self.primitive_length
                primitive_data_list.append(self.get_primitive(seq_data, start_frame, end_frame, skip_text=True))
                
            primitive_dict = {}
            for person in ['person1', 'person2']:
                primitive_dict[person] = {}
                primitive_dict[person]['gender'] = {primitive_data_list[0]['primitive_dict'][person]['gender']}
                for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']:
                    primitive_dict[person][key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                if self.padding:
                    primitive_dict[person]['primitive_padding_mask'] = torch.cat([data['primitive_dict'][person]['primitive_padding_mask'] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(primitive_dict[person], self.device)

            # split primitive_dict into batches
            batch_start_idx = 0
            while batch_start_idx < len(primitive_dict['person1']['body_pose']):
                batch_primitive_dict = {}
                canonicalized_primitive_dict = {}
                if self.use_interaction_model:
                    transf_rotmat, transf_transl = {}, {}
                batch_end_idx = min(batch_start_idx + batch_size, len(primitive_dict['person1']['body_pose']))
                for person in ['person1', 'person2']:
                    batch_primitive_dict[person] = {}
                    batch_primitive_dict[person] = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']}
                    batch_primitive_dict[person]['gender'] = primitive_dict[person]['gender']
                    if self.use_interaction_model:
                        transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(batch_primitive_dict[person]), use_predicted_joints=True)
                    else:
                        _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(batch_primitive_dict[person]), use_predicted_joints=True)
                
                feature_dict = {}
                for person in ['person1', 'person2']:
                    feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    if self.primitive_utility.feature_dim == 276:
                        feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                        feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                        feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]
                    motion_tensor = self.dict_to_tensor(feature_dict[person]).detach().cpu()    # [num_primitive, T, D]
                    if self.padding:
                        mask_slice = primitive_dict[person]['primitive_padding_mask'][batch_start_idx:batch_end_idx, -1]  # [B]
                        valid_indices = torch.nonzero(~mask_slice, as_tuple=True)[0].detach().cpu()  # select valid
                        motion_tensor = motion_tensor[valid_indices]
                    all_mp_data.append(motion_tensor)

                if self.use_interaction_model:
                    B, T, *_ = batch_primitive_dict['person1']['joints'].shape
                    joints_p1 = batch_primitive_dict['person1']['joints'].reshape(B, T, 22, 3)  # [B, T+1, 22, 3]
                    joints_p2 = batch_primitive_dict['person2']['joints'].reshape(B, T, 22, 3)
                    # reltive transition, relative distance
                    rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                    if self.primitive_utility.feature_dim == 276:
                        rel_global_orient['b2a'] = batch_primitive_dict['person1']['global_orient'].transpose(-1, -2) @ batch_primitive_dict['person2']['global_orient']
                        rel_global_orient['a2b'] = batch_primitive_dict['person2']['global_orient'].transpose(-1, -2) @ batch_primitive_dict['person1']['global_orient']
                        # rel_root_transl['b2a'] = torch.matmul(batch_primitive_dict['person1']['global_orient'].transpose(-1, -2), (batch_primitive_dict['person2']['transl']-batch_primitive_dict['person1']['transl']).unsqueeze(-1)).squeeze(-1)
                        # rel_root_transl['a2b'] = torch.matmul(batch_primitive_dict['person2']['global_orient'].transpose(-1, -2), (batch_primitive_dict['person1']['transl']-batch_primitive_dict['person2']['transl']).unsqueeze(-1)).squeeze(-1)
                        rel_root_transl['b2a'] = torch.matmul(batch_primitive_dict['person2']['transl']-batch_primitive_dict['person1']['transl'], transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul(batch_primitive_dict['person1']['transl']-batch_primitive_dict['person2']['transl'], transf_rotmat['person2'])
                    elif self.primitive_utility.feature_dim == 262:
                        rel_global_orient['b2a'], rel_global_orient['a2b'] = cal_rel_rot(joints_p1, joints_p2)
                        rel_root_transl['b2a'] = torch.matmul(joints_p2[:,:,0]-joints_p1[:,:,0], transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul(joints_p1[:,:,0]-joints_p2[:,:,0], transf_rotmat['person2']) 

                    dists = torch.norm(joints_p1.unsqueeze(3) - joints_p2.unsqueeze(2), dim=-1)  # [B, T+1, 22, 22]
                    rel_mindis = {
                        'b2a': dists.min(dim=-1).values,  # [B, T+1, 22]
                        'a2b': dists.min(dim=-2).values,
                    }

                    for key in ['b2a', 'a2b']:
                        rot_6d = transforms.matrix_to_rotation_6d(rel_global_orient[key])  # [B, T+1, 6]
                        rel_info = torch.cat([rot_6d, rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B, T+1, 6+3+22]
                        rel_info = rel_info[:, self.cfg.history_length:-1, :].detach().cpu()  # [B, T-(1+history_length), D]
                        if self.padding:
                            rel_info = rel_info[valid_indices]
                        all_rel_info.append(rel_info)
                batch_start_idx = batch_end_idx

        all_mp_data = torch.cat(all_mp_data, dim=0)                 # [2*N, T, D]
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)    # [1, 1, D]
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)      # [1, 1, D]
        if self.use_interaction_model and len(all_rel_info) > 0:
            all_rel_info = torch.cat(all_rel_info, dim=0)  # [2N, (T-(1+history_length)), D]
            rel_mean = all_rel_info.mean(dim=[0, 1], keepdim=True)
            rel_std = all_rel_info.std(dim=[0, 1], keepdim=True)
            relpose_mean = rel_mean[...,:9]
            relpose_std = rel_std[...,:9]
        if self.normalize_relpose:
            return tensor_mean.to(self.device), tensor_std.to(self.device), relpose_mean.to(self.device), relpose_std.to(self.device)
        if self.use_interaction_model:
            return tensor_mean.to(self.device), tensor_std.to(self.device), rel_mean.to(self.device), rel_std.to(self.device)
        return tensor_mean.to(self.device), tensor_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """Slice frames ``start_frame..end_frame`` (end_frame included) for both persons.

        Args:
            seq_data: sequence record with ``'motion_p1'`` / ``'motion_p2'`` motion dicts
                and, optionally, ``'frame_labels_<key>'`` segment lists for the keys in
                ``self.text_key_list``.
            start_frame: first frame of the primitive.
            end_frame: last frame of the primitive (inclusive).
            skip_text: if True, no label lookup is done and every text is ``''``.

        Returns:
            dict with
              * ``'primitive_dict'``: ``{'person1': {...}, 'person2': {...}}``; every
                per-frame entry gets a leading dimension of size 1 via ``unsqueeze(0)``.
                When ``self.padding`` is set, each person also has
                ``'primitive_padding_mask'`` of shape ``(1, history_length + 1)``.
              * ``'text_<key>'`` for each key in ``self.text_key_list``: a random label whose
                segment overlaps the future window (widened by ``self.text_tolerance``), or
                ``''`` when none does.
        """
        frames = slice(start_frame, end_frame + 1)  # end_frame is inclusive
        primitive_dict = {}
        for person, motion_data in zip(['person1', 'person2'], [seq_data['motion_p1'], seq_data['motion_p2']]):
            primitive_dict[person] = {
                'gender': motion_data['gender'],
                'betas': motion_data['betas'].expand(1, self.primitive_length + 1, 10),
                'transl': motion_data['transl'][frames].unsqueeze(0),
                'global_orient': motion_data['global_orient'][frames].unsqueeze(0),
                'body_pose': motion_data['body_pose'][frames].unsqueeze(0),
                'pelvis_delta': motion_data['pelvis_delta'].unsqueeze(0),  # not sliced by frame
                'joints': motion_data['joints'][frames].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }
            if self.padding:
                padding_mask_full = motion_data['padding_mask'][frames]  # shape [T+1]
                # torch.as_tensor accepts numpy arrays and tensors alike; torch.tensor() on a
                # tensor triggers a copy-construct UserWarning on every call.
                history_mask = torch.as_tensor(padding_mask_full[:self.history_length], dtype=torch.bool)
                # One flag summarizing the future frames (the final frame is excluded).
                future_mask = padding_mask_full[self.history_length:-1]
                future_flag = torch.tensor([bool(future_mask.any())], dtype=torch.bool, device=history_mask.device)
                primitive_dict[person]['primitive_padding_mask'] = torch.cat([history_mask, future_flag], dim=0).unsqueeze(0)  # (1, history_length + 1)

        texts = {key: [] for key in self.text_key_list}
        if not skip_text:
            # Future window in seconds; identical for every text key.
            future_start = (start_frame + self.history_length) / self.target_fps
            future_end = (start_frame + self.history_length + self.future_length - 1) / self.target_fps
            for key_type in self.text_key_list:
                label_key = f'frame_labels_{key_type}'
                if label_key not in seq_data:
                    continue
                for seg in seq_data[label_key]:
                    if have_overlap([seg['start_t'], seg['end_t']], [future_start - self.text_tolerance, future_end + self.text_tolerance]):
                        texts[key_type].append(seg['proc_label'])

        output = {}
        for key_type in self.text_key_list:
            output['text_' + key_type] = random.choice(texts[key_type]) if texts[key_type] else ''
        output['primitive_dict'] = primitive_dict
        return output

    def get_relpose_mean_std_by_device(self, device):
        """Return ``(relpose_mean, relpose_std)`` moved to ``device``, cached per device.

        The per-device copies are stored in ``self.relpose_mean_device_dict`` and are never
        invalidated, so reassigning ``relpose_mean`` / ``relpose_std`` afterwards will not be
        picked up for devices already cached.

        Args:
            device: target device; used both for ``Tensor.to(device=...)`` and as cache key.

        Returns:
            Tuple ``(mean, std)`` of tensors on ``device``.

        Raises:
            AssertionError: if ``relpose_mean`` / ``relpose_std`` are missing or None.
        """
        if not hasattr(self, 'relpose_mean_device_dict'):
            self.relpose_mean_device_dict = {}

        if device not in self.relpose_mean_device_dict:
            # getattr so a never-set attribute hits the explanatory assertion rather than
            # an AttributeError.
            relpose_mean = getattr(self, 'relpose_mean', None)
            relpose_std = getattr(self, 'relpose_std', None)
            assert relpose_mean is not None and relpose_std is not None, "relpose_mean/std must be computed before normalization."
            self.relpose_mean_device_dict[device] = (
                relpose_mean.to(device=device),
                relpose_std.to(device=device),
            )
        return self.relpose_mean_device_dict[device]

    def get_rel_mean_std_by_device(self, device):
        """Fetch ``(rel_mean, rel_std)`` on ``device``, converting at most once per device.

        Converted pairs live in ``self.rel_mean_device_dict`` (created lazily), keyed by
        ``device``; entries are never refreshed.

        Raises:
            AssertionError: when ``rel_mean`` or ``rel_std`` is None at first use on ``device``.
        """
        try:
            per_device = self.rel_mean_device_dict
        except AttributeError:
            per_device = self.rel_mean_device_dict = {}

        cached = per_device.get(device)
        if cached is None:
            assert self.rel_mean is not None and self.rel_std is not None, "rel_mean/std must be computed before normalization."
            cached = per_device[device] = (
                self.rel_mean.to(device=device),
                self.rel_std.to(device=device),
            )
        return cached

    def normalize_rel_pose(self, rel_pose: torch.Tensor) -> torch.Tensor:
        """Z-score ``rel_pose`` using the relpose statistics on ``rel_pose``'s device.

        Statistics come from ``get_relpose_mean_std_by_device``. Entries whose std is exactly
        zero are divided by 1 instead, so those channels are only mean-centered.
        ``rel_pose`` may be [B, D] or [B, T, D] as long as it broadcasts against the stats.
        """
        mean, std = self.get_relpose_mean_std_by_device(rel_pose.device)
        # Out-of-place replacement: the cached std tensor is left untouched.
        divisor = torch.where(std == 0, torch.ones_like(std), std)
        return (rel_pose - mean) / divisor

    def normalize_rel_info(self, rel_info: torch.Tensor) -> torch.Tensor:
        """Standardize ``rel_info`` with ``rel_mean`` / ``rel_std`` on its own device.

        Zero-valued std entries are treated as 1 (mean-centering only) to avoid division
        by zero. Accepts [B, D] or [B, T, D] input, broadcast against the statistics.
        """
        center, scale = self.get_rel_mean_std_by_device(rel_info.device)
        safe_scale = scale.masked_fill(scale == 0, 1.0)  # new tensor; cached stats unchanged
        centered = rel_info - center
        return centered / safe_scale

    def get_batch(self, batch_size=8):
        """Sample a batch of two-person motion-primitive sequences.

        For each sampled sequence, person1/person2 are swapped with probability 0.5. A window
        of ``self.seq_length`` frames is then drawn: text-weighted via ``frame_weights`` when
        ``'text' in self.weight_scheme``, uniform otherwise. The window is cut into
        overlapping primitives with ``get_primitive``. Sequences are grouped by person1's
        gender. Each group is canonicalized, featurized and normalized, and the gender groups
        are concatenated at the end.

        Args:
            batch_size: number of sequences to sample.

        Returns:
            * ``self.mode == 'merged'``: dict with keys ``'person1'``, ``'person2'`` and
              ``'interaction'``, each a list with one dict per primitive.
            * ``self.mode == 'sep'``: list with one dict per primitive. Within each gender
              group, ``gender_batch_size`` person-sequences are drawn at random from the
              ``2 * gender_batch_size`` candidates (person1 and person2 of every sequence).
            * ``None`` if no sampled sequence has person1 gender in
              ``('female', 'male', 'neutral')``.
        """
        self.time = time.time()
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        # Keys merged across gender groups: list-valued keys with `+`, tensors with torch.cat.
        add_key_list = ['gender']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'transf_rotmat', 'transf_transl', 'start_frame', 'total_frames']
        if self.padding:
            cat_key_list.append('primitive_padding_mask')
        if self.use_indi_text:
            add_key_list.append('texts')
            if self.load_text_embedding:
                cat_key_list.append('text_embedding')
            if self.text_sep:
                cat_key_list.append('text_mask')
        # Interaction tensors present in every merged-mode 'interaction' entry.
        interaction_cat_keys = ['rel_pose_b2a', 'rel_pose_a2b']
        if self.use_interaction_model:
            interaction_cat_keys += ['rel_info_b2a', 'rel_info_a2b']
        # Per-person tensors stacked across sequences and primitives before canonicalization.
        seq_key_list = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints', 'start_frame', 'total_frames']

        # 1) Sample sequences and cut each one into consecutive primitives.
        for seq_idx in batch_idx:
            seq_data = dict(self.dataset[seq_idx])  # shallow copy: the swap below must not touch the dataset
            # exchange person1 and person2
            if random.random() < 0.5:
                seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
                if self.use_indi_text:
                    seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = seq_data['frame_labels_person2'], seq_data['frame_labels_person1']
            num_frames = len(seq_data['motion_p1']['transl'])
            if 'text' in self.weight_scheme:
                start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
            else:
                start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
            primitive_data_list = []
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                for person in ['person1', 'person2']:
                    primitive_data['primitive_dict'][person]['start_frame'] = torch.tensor(frame_idx).view(1)
                    primitive_data['primitive_dict'][person]['total_frames'] = torch.tensor(num_frames).view(1)
                primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # 2) Process each gender group (keyed on person1's gender), then concatenate groups.
        batch = None
        for gender in ['female', 'male', 'neutral']:
            gender_idx = [idx for idx in range(len(seq_list)) if seq_list[idx][0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_idx) == 0:
                continue
            gender_seq_list = [seq_list[i] for i in gender_idx]
            gender_batch_size = len(gender_idx)
            gender_batch = {} if self.mode == 'merged' else []

            # Stack primitive-major: rows [k*B:(k+1)*B] hold primitive k of every sequence.
            gender_seq_texts = {key_type: None for key_type in self.text_key_list}
            gender_seq_dict = None
            for primitive_idx in range(self.num_primitive):
                primitive_dict = {}
                for person in ['person1', 'person2']:
                    primitive_dict[person] = {}
                    primitive_dict[person]['gender'] = gender
                    for key in seq_key_list:
                        primitive_dict[person][key] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person][key] for mp_seq in gender_seq_list], dim=0)
                    if self.padding:
                        primitive_dict[person]['primitive_padding_mask'] = torch.cat([mp_seq[primitive_idx]['primitive_dict'][person]['primitive_padding_mask'] for mp_seq in gender_seq_list], dim=0)
                for key_type in self.text_key_list:
                    step_texts = [mp_seq[primitive_idx]['text_' + key_type] for mp_seq in gender_seq_list]
                    gender_seq_texts[key_type] = step_texts if gender_seq_texts[key_type] is None else gender_seq_texts[key_type] + step_texts

                if gender_seq_dict is None:
                    gender_seq_dict = primitive_dict
                else:
                    for person in ['person1', 'person2']:
                        for key in seq_key_list:
                            gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                        if self.padding:
                            gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'],
                                                                                           primitive_dict[person]['primitive_padding_mask']], dim=0)

            canonicalized_primitive_dict = {}
            transf_rotmat, transf_transl = {}, {}
            for person in ['person1', 'person2']:
                gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)
            if self.mode == 'merged':
                # rel_pose: relative transform between the two canonicalization frames, as
                # 6D rotation + translation (b2a / a2b per compute_rel_transform_B_in_A's naming).
                rel_rotmat, rel_transl, rel_pose = {}, {}, {}
                rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
                rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
                for rel in ['b2a', 'a2b']:
                    rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                    rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]

                if self.use_interaction_model:
                    B, T, *_ = gender_seq_dict['person1']['joints'].shape
                    joints_p1 = gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3)
                    joints_p2 = gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3)
                    # relative orientation, relative root translation, per-joint min distance
                    rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                    if self.primitive_utility.feature_dim == 276:
                        rel_global_orient['b2a'] = gender_seq_dict['person1']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person2']['global_orient']
                        rel_global_orient['a2b'] = gender_seq_dict['person2']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person1']['global_orient']
                        rel_root_transl['b2a'] = torch.matmul((gender_seq_dict['person2']['transl'] - gender_seq_dict['person1']['transl']), transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul((gender_seq_dict['person1']['transl'] - gender_seq_dict['person2']['transl']), transf_rotmat['person2'])
                    elif self.primitive_utility.feature_dim == 262:
                        rel_global_orient['b2a'], rel_global_orient['a2b'] = cal_rel_rot(joints_p1, joints_p2)
                        rel_root_transl['b2a'] = torch.matmul(joints_p2[:, :, 0] - joints_p1[:, :, 0], transf_rotmat['person1'])
                        rel_root_transl['a2b'] = torch.matmul(joints_p1[:, :, 0] - joints_p2[:, :, 0], transf_rotmat['person2'])
                    dists = torch.norm(joints_p1.unsqueeze(3) - joints_p2.unsqueeze(2), dim=-1)  # [B, T, 22, 22]
                    rel_mindis['b2a'], _ = dists.min(dim=-1)
                    rel_mindis['a2b'], _ = dists.min(dim=-2)

                    rel_info = {}
                    for key in ['b2a', 'a2b']:
                        rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B*num_mp, T, 6+3+22]

            # calc features
            feature_dict = {}
            for person in ['person1', 'person2']:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                if self.primitive_utility.feature_dim == 276:
                    # drop the trailing frame so these match the delta features' length
                    feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                    feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                    feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]

            if self.mode == 'merged':
                for person in ['person1', 'person2']:
                    gender_batch[person] = []
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        if self.use_indi_text:
                            primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                                text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None

                        gender_batch[person].append({
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                'start_frame': gender_seq_dict[person]['start_frame'][start_idx:end_idx],
                                'total_frames': gender_seq_dict[person]['total_frames'][start_idx:end_idx],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            })
                        if self.use_indi_text:
                            gender_batch[person][-1]['texts'] = primitive_texts
                            if self.load_text_embedding:
                                gender_batch[person][-1]['text_embedding'] = text_embedding
                                gender_batch[person][-1]['text_mask'] = text_mask
                        if self.padding:
                            gender_batch[person][-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                gender_batch['interaction'] = []
                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = (primitive_idx + 1) * gender_batch_size
                    primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                    if self.load_text_embedding:
                        unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                        if len(unseen_texts) > 0:
                            self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                        text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                        if self.text_sep:
                            text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                        else:
                            text_mask = None
                    gender_batch['interaction'].append({
                            'texts': primitive_texts,
                            'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                            'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                        })
                    if self.load_text_embedding:
                        gender_batch['interaction'][-1]['text_embedding'] = text_embedding
                        gender_batch['interaction'][-1]['text_mask'] = text_mask
                    if self.normalize_relpose:
                        gender_batch['interaction'][-1]['rel_pose_b2a'] = self.normalize_rel_pose(gender_batch['interaction'][-1]['rel_pose_b2a'])
                        gender_batch['interaction'][-1]['rel_pose_a2b'] = self.normalize_rel_pose(gender_batch['interaction'][-1]['rel_pose_a2b'])
                    if self.use_interaction_model:
                        gender_batch['interaction'][-1].update({
                            'rel_info_b2a': self.normalize_rel_info(rel_info['b2a'][start_idx:end_idx, self.cfg.history_length:-1]),
                            'rel_info_a2b': self.normalize_rel_info(rel_info['a2b'][start_idx:end_idx, self.cfg.history_length:-1]),
                        })

            elif self.mode == 'sep':
                # gender_batch: person1's num_primitive entries first, then person2's.
                for person in ['person1', 'person2']:
                    motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))
                    motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)                   # [B*num_mp, D, 1, T]
                    history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                    history_mask[..., :self.cfg.history_length] = True
                    history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                    history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                    for primitive_idx in range(self.num_primitive):
                        start_idx = primitive_idx * gender_batch_size
                        end_idx = (primitive_idx + 1) * gender_batch_size
                        if self.use_indi_text:
                            primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                            if self.load_text_embedding:
                                unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                                if len(unseen_texts) > 0:
                                    self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                                text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                                if self.text_sep:
                                    text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                                else:
                                    text_mask = None
                        gender_batch.append(
                            {
                                'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                                'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                                'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                                'history_motion': history_motion[start_idx:end_idx, ...],
                                'history_mask': history_mask[start_idx:end_idx, ...],
                                'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                                'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                                # required: both keys are in cat_key_list and are concatenated below
                                'start_frame': gender_seq_dict[person]['start_frame'][start_idx:end_idx],
                                'total_frames': gender_seq_dict[person]['total_frames'][start_idx:end_idx],
                                'history_length': self.cfg.history_length,
                                'future_length': self.cfg.future_length,
                            }
                        )
                        if self.use_indi_text:
                            gender_batch[-1]['texts'] = primitive_texts
                            if self.load_text_embedding:
                                gender_batch[-1]['text_embedding'] = text_embedding
                                gender_batch[-1]['text_mask'] = text_mask
                        if self.padding:
                            gender_batch[-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                # Randomly keep gender_batch_size of the 2 * gender_batch_size person-sequences;
                # selector[:B] flags person1 rows, selector[B:] flags person2 rows.
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]

                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key], back_group[key] = [], []
                    for d in gender_batch[:self.num_primitive]:
                        front_group[key] += d[key]
                    for d in gender_batch[self.num_primitive:]:
                        back_group[key] += d[key]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in gender_batch[:self.num_primitive]], dim=0)
                    back_group[key] = torch.cat([d[key] for d in gender_batch[self.num_primitive:]], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]

                selected_batch = []
                for prim_i in range(self.num_primitive):
                    # rows of primitive prim_i inside the primitive-major grouped tensors/lists
                    front_rows = front_indices + prim_i * gender_batch_size
                    back_rows = back_indices + prim_i * gender_batch_size
                    selected_dict = {'history_length': self.cfg.history_length,
                                     'future_length': self.cfg.future_length}
                    for key in front_group.keys():
                        if key in add_key_list:
                            selected_front = [front_group[key][row] for row in front_rows.tolist()]
                            selected_back = [back_group[key][row] for row in back_rows.tolist()]
                            selected_dict[key] = selected_front + selected_back
                        elif key in cat_key_list:
                            selected_front = front_group[key][front_rows]
                            selected_back = back_group[key][back_rows]
                            selected_dict[key] = torch.cat([selected_front, selected_back], dim=0)
                    selected_batch.append(selected_dict)
                gender_batch = selected_batch

            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batches
                for primitive_idx in range(self.num_primitive):
                    if self.mode == 'merged':
                        for key_type in self.key_list:
                            acc = batch[key_type][primitive_idx]
                            new = gender_batch[key_type][primitive_idx]
                            if key_type != 'interaction':
                                for key in add_key_list:
                                    acc[key] = acc[key] + new[key]
                                for key in cat_key_list:
                                    acc[key] = torch.cat([acc[key], new[key]], dim=0)
                            else:
                                acc['texts'] = acc['texts'] + new['texts']
                                # rel_info_* only exist when use_interaction_model is set
                                for key in interaction_cat_keys:
                                    acc[key] = torch.cat([acc[key], new[key]], dim=0)
                                if self.load_text_embedding:
                                    acc['text_embedding'] = torch.cat([acc['text_embedding'], new['text_embedding']], dim=0)
                                    if self.text_sep:  # text_mask is None otherwise
                                        acc['text_mask'] = torch.cat([acc['text_mask'], new['text_mask']], dim=0)
                    else:
                        acc = batch[primitive_idx]
                        new = gender_batch[primitive_idx]
                        for key in add_key_list:
                            acc[key] = acc[key] + new[key]
                        for key in cat_key_list:
                            acc[key] = torch.cat([acc[key], new[key]], dim=0)

        return batch
    
    def get_item(self, idx):
        """Sample a training window from sequence ``idx`` and build per-primitive inputs.

        A start frame is drawn (weighted by ``seq_data['frame_weights']`` when
        ``'text'`` is in ``self.weight_scheme``, uniformly otherwise). The window is
        cut into ``self.num_primitive`` primitives of ``self.primitive_length`` frames
        that start every ``self.future_length`` frames. Each person's motion is
        canonicalized with ``self.primitive_utility`` and converted to normalized
        features.

        Args:
            idx: index into ``self.dataset``.

        Returns:
            In ``'merged'`` mode: a dict with keys ``'person1'`` and ``'person2'``
            (each a list of ``self.num_primitive`` per-primitive dicts) and
            ``'interaction'`` (a list of per-primitive dicts holding texts and
            relative poses). In ``'sep'`` mode: a list of ``self.num_primitive``
            per-primitive dicts for one person, chosen at random.

        Raises:
            ValueError: ``self.use_interaction_model`` is set in ``'merged'`` mode and
                ``self.primitive_utility.feature_dim`` is neither 276 nor 262.
        """
        seq_data = self.dataset[idx]
        num_frames = len(seq_data['motion_p1']['transl'])
        gender = {}
        gender['person1'] = seq_data['motion_p1']['gender']
        gender['person2'] = seq_data['motion_p2']['gender']
        if 'text' in self.weight_scheme:
            start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
        else:
            start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
        # Primitives start every future_length frames and span primitive_length frames,
        # so consecutive primitives overlap by primitive_length - future_length frames.
        primitive_data_list = []
        for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
            primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
            primitive_data_list.append(primitive_data)

        motion_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']

        # Concatenate all primitives along dim 0 so each person is processed as one batch.
        gender_seq_texts = {key_type: [] for key_type in self.text_key_list}
        gender_seq_dict = None
        for primitive_idx in range(self.num_primitive):
            primitive_dict = {}
            for person in ['person1', 'person2']:
                primitive_dict[person] = {}
                primitive_dict[person]['gender'] = gender[person]
                for key in motion_keys:
                    primitive_dict[person][key] = primitive_data_list[primitive_idx]['primitive_dict'][person][key]
                if self.padding:
                    primitive_dict[person]['primitive_padding_mask'] = primitive_data_list[primitive_idx]['primitive_dict'][person]['primitive_padding_mask']
            for key_type in self.text_key_list:
                gender_seq_texts[key_type].append(primitive_data_list[primitive_idx]['text_' + key_type])

            if gender_seq_dict is None:
                gender_seq_dict = primitive_dict
            else:
                for person in ['person1', 'person2']:
                    for key in motion_keys:
                        gender_seq_dict[person][key] = torch.cat([gender_seq_dict[person][key], primitive_dict[person][key]], dim=0)
                    if self.padding:
                        gender_seq_dict[person]['primitive_padding_mask'] = torch.cat([gender_seq_dict[person]['primitive_padding_mask'],
                                                                                       primitive_dict[person]['primitive_padding_mask']], dim=0)

        # Canonicalize each person; in merged mode keep the canonical transforms so the
        # relative pose between the two people can be computed below.
        canonicalized_primitive_dict = {}
        if self.mode == 'merged':
            transf_rotmat, transf_transl = {}, {}
        for person in ['person1', 'person2']:
            gender_seq_dict[person] = tensor_dict_to_device(gender_seq_dict[person], self.device)
            if self.mode == 'merged':
                transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)
            else:
                _, _, canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict[person]), use_predicted_joints=True)

        if self.mode == 'merged':
            # rel_pose: transform of each person's canonical frame expressed in the other's.
            rel_rotmat, rel_transl, rel_pose = {}, {}, {}
            rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
            for rel in ['b2a', 'a2b']:
                rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]

            if self.use_interaction_model:
                B, T, *_ = gender_seq_dict['person1']['joints'].shape
                joints_p1 = gender_seq_dict['person1']['joints'].reshape(B, T, 22, 3)
                joints_p2 = gender_seq_dict['person2']['joints'].reshape(B, T, 22, 3)
                # relative orientation, relative root translation, per-joint min distance
                rel_global_orient, rel_root_transl, rel_mindis = {}, {}, {}
                if self.primitive_utility.feature_dim == 276:
                    rel_global_orient['b2a'] = gender_seq_dict['person1']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person2']['global_orient']
                    rel_global_orient['a2b'] = gender_seq_dict['person2']['global_orient'].transpose(-1, -2) @ gender_seq_dict['person1']['global_orient']
                    rel_root_transl['b2a'] = torch.matmul((gender_seq_dict['person2']['transl'] - gender_seq_dict['person1']['transl']), transf_rotmat['person1'])
                    rel_root_transl['a2b'] = torch.matmul((gender_seq_dict['person1']['transl'] - gender_seq_dict['person2']['transl']), transf_rotmat['person2'])
                elif self.primitive_utility.feature_dim == 262:
                    rel_global_orient['b2a'], rel_global_orient['a2b'] = cal_rel_rot(joints_p1, joints_p2)
                    rel_root_transl['b2a'] = torch.matmul(joints_p2[:, :, 0] - joints_p1[:, :, 0], transf_rotmat['person1'])
                    rel_root_transl['a2b'] = torch.matmul(joints_p1[:, :, 0] - joints_p2[:, :, 0], transf_rotmat['person2'])
                else:
                    raise ValueError(f'unsupported feature_dim for interaction features: {self.primitive_utility.feature_dim}')
                # dists[..., i, j] = distance between person1 joint i and person2 joint j -> [B, T, 22, 22]
                dists = torch.norm(joints_p1.unsqueeze(3) - joints_p2.unsqueeze(2), dim=-1)
                rel_mindis['b2a'], _ = dists.min(dim=-1)
                rel_mindis['a2b'], _ = dists.min(dim=-2)

                rel_info = {}
                for key in ['b2a', 'a2b']:
                    rel_info[key] = torch.cat([transforms.matrix_to_rotation_6d(rel_global_orient[key]), rel_root_transl[key], rel_mindis[key]], dim=-1)  # [B*num_mp, T, 6+3+22]

        # calc features
        feature_dict = {}
        for person in ['person1', 'person2']:
            feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
            if self.primitive_utility.feature_dim == 276:
                # drop the last frame of the per-frame features
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [B*num_mp, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [B*num_mp, T, 66]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [B*num_mp, T, 22 * 3]

        data_batch = {} if self.mode == 'merged' else []
        if self.mode == 'merged':
            for person in ['person1', 'person2']:
                data_batch[person] = []
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [B*num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [B*num_mp, D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx
                    end_idx = primitive_idx + 1
                    if self.use_indi_text:
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        if self.load_text_embedding:
                            unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                            if len(unseen_texts) > 0:
                                self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                            text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                            if self.text_sep:
                                text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                            else:
                                text_mask = None

                    data_batch[person].append({
                            'gender': [gender_seq_dict[person]['gender']],
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [B, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                            'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        })
                    if self.use_indi_text:
                        data_batch[person][-1]['texts'] = primitive_texts
                        if self.load_text_embedding:
                            data_batch[person][-1]['text_embedding'] = text_embedding
                            data_batch[person][-1]['text_mask'] = text_mask
                    if self.padding:
                        data_batch[person][-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
            data_batch['interaction'] = []
            for primitive_idx in range(self.num_primitive):
                start_idx = primitive_idx
                end_idx = primitive_idx + 1
                primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                if self.load_text_embedding:
                    unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict['interaction']]
                    if len(unseen_texts) > 0:
                        self.update_text_embedding_dict(unseen_texts, 'interaction', text_sep=self.text_sep, max_segs=self.max_segs)
                    text_embedding = torch.stack([self.text_embedding_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, 512]
                    if self.text_sep:
                        text_mask = torch.stack([self.text_mask_dict['interaction'][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                    else:
                        text_mask = None
                data_batch['interaction'].append({
                        'texts': primitive_texts,
                        'rel_pose_b2a': rel_pose['b2a'][start_idx:end_idx],
                        'rel_pose_a2b': rel_pose['a2b'][start_idx:end_idx],
                    })
                if self.load_text_embedding:
                    data_batch['interaction'][-1]['text_embedding'] = text_embedding
                    data_batch['interaction'][-1]['text_mask'] = text_mask
                if self.use_interaction_model:
                    data_batch['interaction'][-1].update({
                        'rel_info_b2a': rel_info['b2a'][start_idx:end_idx, self.cfg.history_length:-1].reshape(1, -1),
                        'rel_info_a2b': rel_info['a2b'][start_idx:end_idx, self.cfg.history_length:-1].reshape(1, -1),
                    })
            return data_batch
        else:
            for person in ['person1', 'person2']:
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))     # [num_mp, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [num_mp, D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx
                    end_idx = primitive_idx + 1
                    if self.use_indi_text:
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        if self.load_text_embedding:
                            unseen_texts = [text for text in primitive_texts if text not in self.text_embedding_dict[person]]
                            if len(unseen_texts) > 0:
                                # Same path as the merged branch: this also fills text_mask_dict
                                # when text_sep is on (the lookup below would otherwise KeyError).
                                self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                            text_embedding = torch.stack([self.text_embedding_dict[person][text] for text in primitive_texts], dim=0)  # [B, 512]
                            if self.text_sep:
                                text_mask = torch.stack([self.text_mask_dict[person][text] for text in primitive_texts], dim=0)  # [B, max_segs]
                            else:
                                text_mask = None
                    data_batch.append(
                        {
                            'gender': [gender_seq_dict[person]['gender']],
                            'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                            'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...], # [1, D, 1, T]
                            'history_motion': history_motion[start_idx:end_idx, ...],
                            'history_mask': history_mask[start_idx:end_idx, ...],
                            'history_length': self.cfg.history_length,
                            'future_length': self.cfg.future_length,
                        }
                    )
                    if self.use_indi_text:
                        data_batch[-1]['texts'] = primitive_texts
                        if self.load_text_embedding:
                            data_batch[-1]['text_embedding'] = text_embedding
                            data_batch[-1]['text_mask'] = text_mask

                    if self.padding:
                        data_batch[-1]['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
            # data_batch holds person1's primitives followed by person2's; return one person at random.
            if random.random() < 0.5:
                return data_batch[:self.num_primitive]
            else:
                return data_batch[self.num_primitive:]


class InterGenDatasetWPERT(WeightedPrimitiveSequenceDataset):
    def __init__(self, dataset_name='interhuman',
                 dataset_path='./data/InterHuman/seq_data_single_interaction_d262_fps30_mirror_exchangeyz',
                 cfg_path='./config_files/config_hydra/motion_primitive/interhuman_h2_f8_r4.yaml',
                 split="train",
                 device='cuda',
                 weight_scheme='uniform',
                 prob_static=0.0,
                 enforce_gender=None,
                 enforce_zero_beta=None,
                 load_data=True,
                 text_tolerance=0.0,
                 body_type='smplh',
                 seed_only=False,
                 use_frame_weights=True,
                 mode='merged', # 'sep' or 'merged'
                 text_sep=False,
                 max_segs=20,
                 motion_repr=None,
                 **kwargs):
        """Load two-person sequences, normalization statistics and optional CLIP text embeddings.

        Args:
            dataset_name: stored on the instance.
            dataset_path: directory holding ``{split}.pkl``, the cached mean/std pickle
                and the cached text-embedding pickles.
            cfg_path: YAML config loaded with OmegaConf; must define ``fps``,
                ``history_length``, ``future_length`` and ``num_primitive``.
            split: dataset split name (``'test'`` enables sequential batching).
            device: torch device for the utility, CLIP model and embeddings.
            weight_scheme: sequence sampling weights; must contain ``'uniform'`` or
                ``'length'``.
            prob_static, text_tolerance: stored on the instance.
            enforce_gender: if not None, overrides both people's gender.
            enforce_zero_beta: if truthy, replaces betas with zeros.
            load_data: if False, sequence data is not loaded.
            body_type: forwarded to ``PrimitiveUtility``.
            seed_only: if True, requires ``num_primitive == 1``.
            use_frame_weights: accepted for interface compatibility; not read here.
            mode: ``'merged'`` adds an ``'interaction'`` key to ``key_list``;
                ``'sep'`` keeps per-person texts.
            text_sep: encode each text as up to ``max_segs`` segments plus a mask.
            max_segs: maximum number of text segments when ``text_sep`` is on.
            motion_repr: feature layout forwarded to ``PrimitiveUtility``; ``None``
                selects joints (66) + joints_delta (66) + body_pose (126) +
                feet_contact (4).
            **kwargs: optional ``sep_mode``, ``padding``, ``normalize_relpose``,
                ``use_interaction_model``, ``clip_version``, ``load_text_embedding``,
                ``use_indi_text``.

        Raises:
            ValueError: ``weight_scheme`` matches neither ``'uniform'`` nor ``'length'``.
        """
        if motion_repr is None:
            # Fresh dict per call instead of a shared mutable default argument.
            motion_repr = {'joints': 22 * 3,
                           'joints_delta': 22 * 3,
                           'body_pose': 21 * 6,
                           'feet_contact': 4,}
        self.dataset_name = dataset_name
        self.dataset_path = dataset_path
        self.split = split
        self.device = device
        self.weight_scheme = weight_scheme
        self.prob_static = prob_static
        self.enforce_gender = enforce_gender
        self.enforce_zero_beta = enforce_zero_beta
        print('enforce_gender: ', enforce_gender)
        print('enforce_zero_beta: ', enforce_zero_beta)

        self.text_tolerance = text_tolerance
        self.seed_only = seed_only
        self.mode = mode
        self.text_sep = text_sep
        self.max_segs = max_segs
        self.sep_mode = kwargs.get('sep_mode', 0)
        self.padding = kwargs.get('padding', False)
        self.normalize_relpose = kwargs.get('normalize_relpose', False)
        self.use_interaction_model = kwargs.get('use_interaction_model', False)
        self.key_list = ['person1', 'person2', 'interaction'] if self.mode == 'merged' else ['person1', 'person2']
        self.feet_thre = 0.001
        self.n_joints = 22

        self.clip_version = kwargs.get('clip_version', 'ViT-B/32')
        self.load_text_embedding = kwargs.get('load_text_embedding', False)
        self.use_indi_text = kwargs.get('use_indi_text', False)
        # Per-person texts are kept in 'sep' mode or when explicitly requested.
        self.text_key_list = ['person1', 'person2', 'interaction'] if (self.use_indi_text or self.mode == 'sep') else ['interaction']

        self.primitive_utility = PrimitiveUtility(device=self.device, body_type=body_type, motion_repr=motion_repr)
        self.motion_repr = self.primitive_utility.motion_repr

        with open(cfg_path, 'r') as f:
            self.cfg = OmegaConf.load(f)
        self.target_fps = self.cfg.fps
        self.history_length = self.cfg.history_length
        self.future_length = self.cfg.future_length
        self.primitive_length = self.history_length + self.future_length
        self.num_primitive = self.cfg.num_primitive
        if seed_only:
            assert self.num_primitive == 1
        # +1 frame beyond the primitives themselves (presumably for next-frame / delta
        # features; verify against get_primitive).
        self.seq_length = self.history_length + self.future_length * self.num_primitive + 1
        self.min_length = self.history_length + self.future_length + 1

        def convert_motion(motion, gender, enforce_zero_beta):
            """Convert one person's numpy motion dict to float32 torch tensors."""
            betas = torch.from_numpy(motion['betas'].astype(np.float32))
            if enforce_zero_beta:
                betas = torch.zeros_like(betas)
            transl = torch.from_numpy(motion['trans'].astype(np.float32))
            global_orient = transforms.axis_angle_to_matrix(torch.from_numpy(motion['global_orient'].astype(np.float32)))
            body_pose = torch.from_numpy(motion['pose_body'].astype(np.float32)).reshape(-1, 21, 6) # [T, 21, 6]
            pelvis_delta = torch.from_numpy(motion['pelvis_delta'].astype(np.float32))              # [3]
            joints = torch.from_numpy(motion['joints'].astype(np.float32))                          # [T, 22, 3]
            result = {
                'gender': gender,
                'betas': betas,
                'transl': transl,
                'global_orient': global_orient,
                'body_pose': body_pose,
                'pelvis_delta': pelvis_delta,
                'joints': joints,
            }
            if self.padding:
                result['padding_mask'] = motion['padding_mask']
            return result

        if load_data:
            # NOTE(review): pickle is only safe on trusted, locally produced files.
            with open(pjoin(dataset_path, f'{split}.pkl'), 'rb') as f:
                dataset = pickle.load(f)
            if not self.padding:
                dataset = [data for data in dataset if len(data['motion_p1']['trans']) >= self.seq_length]

            # Explicitly excluded sequence names (reason not recorded here).
            elements_to_remove = ['7220', '7221', '6028', '7543', '6940', '4434', '7561', '4385']
            dataset = [data for data in dataset if (data['seq_name'] not in elements_to_remove and data['motion_p1']['trans'].shape[0] >= self.min_length)]

            for data in dataset:
                if self.padding:
                    # Lengths of both people are taken from person1.
                    T = data['motion_p1']['trans'].shape[0]
                    if T < self.seq_length:
                        pad_len = self.seq_length - T
                        for person in ['motion_p1', 'motion_p2']:
                            # Pad by repeating the last frame; betas and pelvis_delta are not padded.
                            for key in ['trans', 'global_orient', 'pose_body', 'joints']:
                                last_frame = data[person][key][-1:]
                                pad_block = np.repeat(last_frame, pad_len, axis=0)
                                data[person][key] = np.concatenate([data[person][key], pad_block], axis=0)

                            # padding_mask is True on padded frames.
                            data[person]['padding_mask'] = np.concatenate([
                                np.zeros(T, dtype=np.bool_),
                                np.ones(pad_len, dtype=np.bool_)
                            ], axis=0)
                    else:
                        for person in ['motion_p1', 'motion_p2']:
                            data[person]['padding_mask'] = np.zeros(T, dtype=np.bool_)

                gender_p1 = self.enforce_gender if self.enforce_gender is not None else data['motion_p1']['gender']
                gender_p2 = self.enforce_gender if self.enforce_gender is not None else data['motion_p2']['gender']
                data['motion_p1'] = convert_motion(data['motion_p1'], gender_p1, self.enforce_zero_beta)
                data['motion_p2'] = convert_motion(data['motion_p2'], gender_p2, self.enforce_zero_beta)

            print('num of sequences: ', len(dataset))

            # assign sampling weights to each sequence
            for data in dataset:
                if 'uniform' in weight_scheme:
                    data['weight'] = 1.0
                elif 'length' in weight_scheme:
                    if self.padding and len(data['motion_p1']['transl']) == self.seq_length:
                        # weight by the number of real (non-padded) frames
                        data['weight'] = len(data['motion_p1']['transl']) - sum(data['motion_p1']['padding_mask'])
                    else:
                        data['weight'] = len(data['motion_p1']['transl'])
                else:
                    raise ValueError(f'unsupported weight_scheme: {weight_scheme}')
            print('finish first assigning seq weights')

            seq_weights = np.array([data['weight'] for data in dataset])
            seq_weights = seq_weights / seq_weights.sum()

            self.dataset = dataset
            self.seq_weights = seq_weights
        self._curr_test_index = 0

        # load or calc mean and std
        self.tensor_mean_device_dict = {}
        file_name = f'mean_std_h{self.history_length}_f{self.future_length}'

        suffix = '_padding' if self.padding else ''
        mean_std_path = Path(dataset_path, f'{file_name}{suffix}_relative.pkl')

        if mean_std_path.exists():
            print(f'loading mean and std from {mean_std_path}')
            with open(mean_std_path, 'rb') as f:
                self.tensor_mean, self.tensor_std = pickle.load(f)  # [1, 1, D]
        else:
            assert self.split == 'train'
            print('calculating mean and std using train split')
            result = self.calc_mean_std()

            self.tensor_mean, self.tensor_std = result
            self.rel_mean, self.rel_std = None, None
            with open(mean_std_path, 'wb') as f:
                pickle.dump((self.tensor_mean.detach().cpu(), self.tensor_std.detach().cpu()), f)

        # load clip model, get train text embeddings
        if self.load_text_embedding:
            self.load_and_freeze_clip(clip_version=self.clip_version, device=self.device)
            self.dim_embed_text = self.clip_model.ln_final.normalized_shape[0]
            suffix = '' if self.clip_version == 'ViT-B/32' else f"_{self.clip_version.replace('/', '')}"
            self.embedding_path = {}
            embedding_path = {}
            for key_type in self.text_key_list:
                if text_sep:
                    self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict_textsep_sepmode{self.sep_mode}{suffix}.pkl')
                else:
                    self.embedding_path[key_type] = embedding_path[key_type] = Path(dataset_path, f'{split}_{key_type}_text_embedding_dict{suffix}.pkl')
            self.text_embedding_dict = {}
            if text_sep:
                self.text_mask_dict = {}

            for key_type in self.text_key_list:
                if embedding_path[key_type].exists():
                    print(f"Loading text_{key_type} embeddings from {embedding_path[key_type]}!")
                    with open(embedding_path[key_type], 'rb') as f:
                        self.text_embedding_dict[key_type] = pickle.load(f)
                    if text_sep:
                        with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'rb') as f:
                            self.text_mask_dict[key_type] = pickle.load(f)
                else:
                    print('Calculating text embeddings')
                    raw_texts = []
                    for data in self.dataset:
                        if f'frame_labels_{key_type}' in data:
                            raw_texts.extend([seg['proc_label'] for seg in data['frame_labels_' + key_type]])

                    raw_texts = list(set(raw_texts))
                    num_texts = len(raw_texts)
                    print(f'num of unique texts_{key_type}: ', len(raw_texts))

                    # get text embeddings in batches of 256
                    text_embeddings = []
                    text_mask = []
                    batch_start_idx = 0
                    while batch_start_idx < num_texts:
                        batch_end_idx = min(batch_start_idx + 256, num_texts)
                        text_embeddings_temp = self.encode_text(raw_texts[batch_start_idx:batch_end_idx], text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
                        if text_sep:
                            text_embeddings.append(text_embeddings_temp[0])
                            text_mask.append(text_embeddings_temp[1])
                        else:
                            text_embeddings.append(text_embeddings_temp)
                        batch_start_idx = batch_end_idx
                    # torch.cat raises on an empty list; with no labelled texts nothing is indexed below.
                    if num_texts > 0:
                        text_embeddings = torch.cat(text_embeddings, dim=0).detach().cpu().numpy()

                    self.text_embedding_dict[key_type] = {raw_texts[idx]: text_embeddings[idx] for idx in range(num_texts)}
                    if text_sep:
                        self.text_embedding_dict[key_type][''] = np.zeros((self.max_segs, self.dim_embed_text)).astype(np.float32)
                    else:
                        self.text_embedding_dict[key_type][''] = np.zeros(self.dim_embed_text).astype(np.float32)  # for empty text have zero embedding, compatible with mdm text masking
                    with open(embedding_path[key_type], 'wb') as f:
                        pickle.dump(self.text_embedding_dict[key_type], f)
                    if text_sep:
                        if num_texts > 0:
                            text_mask = torch.cat(text_mask, dim=0).detach().cpu().numpy()
                        self.text_mask_dict[key_type] = {raw_texts[idx]: text_mask[idx] for idx in range(num_texts)}
                        # NOTE(review): the empty text gets an all-False mask, whereas encode_text('')
                        # would flag every segment as empty; presumably deliberate so attention
                        # is never fully masked - confirm.
                        self.text_mask_dict[key_type][''] = np.zeros(max_segs).astype(np.bool_)
                        with open(Path(str(embedding_path[key_type]).replace('text_embedding_dict', 'text_mask_dict')), 'wb') as f:
                            pickle.dump(self.text_mask_dict[key_type], f)

                # Cached entries are numpy arrays; move them to torch on the target device.
                for key in self.text_embedding_dict[key_type]:
                    self.text_embedding_dict[key_type][key] = torch.from_numpy(self.text_embedding_dict[key_type][key]).to(dtype=torch.float32, device=self.device)
                    if text_sep:
                        self.text_mask_dict[key_type][key] = torch.from_numpy(self.text_mask_dict[key_type][key]).to(dtype=torch.bool, device=self.device)

    def load_and_freeze_clip(self, clip_version, device='cpu'):
        """Load a CLIP model into ``self.clip_model`` and freeze it.

        Args:
            clip_version: model name passed to ``clip.load`` (e.g. ``'ViT-B/32'``).
            device: device passed to ``clip.load``.
        """
        # jit=False is required so the model behaves as a regular module during training.
        model, _preprocess = clip.load(clip_version, device=device, jit=False)
        self.clip_model = model
        # Per the original note this is redundant: clip already loads weights in float16.
        clip.model.convert_weights(model)

        # Inference-only: eval mode and no gradients for any parameter.
        model.eval()
        model.requires_grad_(False)
    
    def encode_text(self, raw_text, force_empty_zero=True, text_sep=False, max_segs = 20, sep_mode=0):
        """Encode a list of texts with the frozen CLIP text encoder.

        Args:
            raw_text: list of strings.
            force_empty_zero: zero the embedding of empty texts (or of empty
                segments when ``text_sep`` is on).
            text_sep: if True, split each text into at most ``max_segs`` segments
                and encode every segment separately.
            max_segs: number of segment slots per text when ``text_sep`` is on.
            sep_mode: ``0`` splits on ``,`` and ``.``; ``1`` additionally splits on
                the whole words ``and`` / ``while``.

        Returns:
            If ``text_sep`` is False: float tensor ``[B, embed_dim]``.
            Otherwise: ``(embedding [B, max_segs, embed_dim], mask [B, max_segs])``,
            where the bool mask is True for empty segments.

        Raises:
            ValueError: ``text_sep`` is on and ``sep_mode`` is not 0 or 1.
        """
        import pandas as pd
        device = next(self.clip_model.parameters()).device
        embed_dim = self.dim_embed_text
        batch_size = len(raw_text)

        if not text_sep:
            with torch.no_grad():
                texts = clip.tokenize(raw_text, truncate=True).to(device)  # [B, context_length]
                text_embedding = self.clip_model.encode_text(texts).float()  # [B, 512]
                if force_empty_zero:
                    empty_text = [t == '' for t in raw_text]
                    text_embedding[empty_text, :] = 0
                return text_embedding

        if sep_mode == 0:
            split_pattern = r'[,.]'
        elif sep_mode == 1:
            split_pattern = r'\band\b|\bwhile\b|,|\.'
        else:
            raise ValueError(f'unsupported sep_mode: {sep_mode}')

        # Drop surrounding whitespace and the trailing period before splitting.
        raw_series = pd.Series(raw_text).str.strip().str.rstrip('.')
        split_df = raw_series.str.split(split_pattern, n=max_segs - 1, expand=True)
        # DataFrame.applymap is deprecated (pandas >= 2.1); strip column-wise instead.
        split_df = split_df.fillna('').astype(str).apply(lambda col: col.str.strip())

        # Always produce exactly max_segs columns, padding with empty segments.
        split_df = split_df.reindex(columns=range(max_segs), fill_value='')

        segs_matrix = split_df.values
        segs_flat = segs_matrix.reshape(-1).tolist()

        text_mask = (segs_matrix == '').astype(bool)
        text_mask = torch.tensor(text_mask, dtype=torch.bool, device=device)

        with torch.no_grad():
            tokenized = clip.tokenize(segs_flat, truncate=True).to(device)  # [B*max_segs, context_length]
            text_embedding = self.clip_model.encode_text(tokenized).float()      # [B*max_segs, 512]
            text_embedding = text_embedding.view(batch_size, max_segs, embed_dim)  # [B, max_segs, 512]

            if force_empty_zero:
                text_embedding[text_mask] = 0

        return text_embedding, text_mask
    
    def get_batch_idx(self, batch_size=8):
        """Return the dataset indices for the next batch.

        On the ``'test'`` split, indices are consecutive, starting from an internal
        cursor. The final batch may be shorter, and the cursor wraps to 0 once the
        end of the dataset is reached. On any other split, ``batch_size`` indices are
        drawn with replacement, weighted by ``self.seq_weights``.
        """
        num_seqs = len(self.dataset)
        if self.split != 'test':
            return np.random.choice(num_seqs, size=batch_size, replace=True, p=self.seq_weights)

        begin = self._curr_test_index
        stop = begin + batch_size
        indices = np.arange(begin, min(stop, num_seqs))
        self._curr_test_index = stop if stop < num_seqs else 0
        return indices

    def update_text_embedding_dict(self, new_texts, key_type, text_sep=False, max_segs=20):
        """Encode ``new_texts`` and cache the results under ``key_type``.

        With ``text_sep`` on, ``encode_text`` returns ``(embeddings, masks)``. Both are
        stored, in ``self.text_embedding_dict`` and ``self.text_mask_dict``.
        Otherwise only the embeddings are stored.
        """
        encoded = self.encode_text(new_texts, text_sep=text_sep, max_segs=max_segs, sep_mode=self.sep_mode)
        embedding_cache = self.text_embedding_dict[key_type]
        if text_sep:
            for i, text in enumerate(new_texts):
                embedding_cache[text] = encoded[0][i]
                self.text_mask_dict[key_type][text] = encoded[1][i]
        else:
            for i, text in enumerate(new_texts):
                embedding_cache[text] = encoded[i]

    def calc_mean_std(self, batch_size=512):
        """Compute feature-wise mean and std of motion-primitive features over ``self.dataset``.

        For every sequence, the two persons are swapped with probability 0.5.
        The sequence is then cut into primitives of ``primitive_length + 1``
        frames with stride ``future_length``. Primitives are processed in chunks
        of ``batch_size``:

        - person1 is canonicalised;
        - person2 is expressed in person1's canonical frame;
        - the features of both persons are collected.

        When ``self.padding`` is set, primitives whose future window contains
        padded frames are dropped.

        Args:
            batch_size: primitives canonicalised per chunk (capped at 64 when
                ``future_length == 1``).

        Returns:
            ``(mean, std)`` reduced over the primitive and time dims (dims 0 and 1,
            ``keepdim=True``), moved to ``self.device``.

        Raises:
            RuntimeError: if no sequence is long enough to yield a primitive.
        """
        if self.future_length == 1:
            batch_size = min(batch_size, 64)
        tensor_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl', 'pelvis_delta', 'joints']
        all_mp_data = []
        for stored_seq in self.dataset:
            # Shallow copy so the random swap below does not permanently reorder the
            # stored sequence (get_batch copies the same way).
            seq_data = dict(stored_seq)
            # exchange person1 and person2
            if random.random() < 0.5:
                seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
                if self.use_indi_text:
                    seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = seq_data['frame_labels_person2'], seq_data['frame_labels_person1']
            num_frames = seq_data['motion_p1']['body_pose'].shape[0]
            # get_primitive slices start..start+primitive_length inclusive, so the last
            # valid start is num_frames - primitive_length - 1.
            primitive_data_list = [
                self.get_primitive(seq_data, start_frame, start_frame + self.primitive_length, skip_text=True)
                for start_frame in range(0, num_frames - self.primitive_length, self.future_length)
            ]
            if len(primitive_data_list) == 0:
                # Sequence too short to hold a single primitive; nothing to collect.
                continue

            primitive_dict = {}
            for person in ['person1', 'person2']:
                # Gender is kept as a plain value, matching get_batch.
                person_dict = {'gender': primitive_data_list[0]['primitive_dict'][person]['gender']}
                for key in tensor_keys:
                    person_dict[key] = torch.cat([data['primitive_dict'][person][key] for data in primitive_data_list], dim=0)
                if self.padding:
                    person_dict['primitive_padding_mask'] = torch.cat([data['primitive_dict'][person]['primitive_padding_mask'] for data in primitive_data_list], dim=0)
                primitive_dict[person] = tensor_dict_to_device(person_dict, self.device)

            # Process the primitives of this sequence in chunks of batch_size.
            num_primitives = len(primitive_dict['person1']['body_pose'])
            for batch_start_idx in range(0, num_primitives, batch_size):
                batch_end_idx = min(batch_start_idx + batch_size, num_primitives)
                batch_primitive_dict = {}
                for person in ['person1', 'person2']:
                    batch_primitive_dict[person] = {key: primitive_dict[person][key][batch_start_idx:batch_end_idx] for key in tensor_keys}
                    batch_primitive_dict[person]['gender'] = primitive_dict[person]['gender']

                canonicalized_primitive_dict = {}
                transf_rotmat, transf_transl, canonicalized_primitive_dict['person1'] = self.primitive_utility.canonicalize(copy.deepcopy(batch_primitive_dict['person1']), use_predicted_joints=True)
                # person2 is expressed in person1's canonical frame.
                canonicalized_primitive_dict['person2'] = self.primitive_utility.relative_canonicalize(copy.deepcopy(batch_primitive_dict['person2']), transf_rotmat, transf_transl)

                for person in ['person1', 'person2']:
                    features = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                    if self.primitive_utility.feature_dim == 276:
                        # Drop the trailing extra frame for this feature layout.
                        features['transl'] = features['transl'][:, :-1, :]
                        features['poses_6d'] = features['poses_6d'][:, :-1, :]
                        features['joints'] = features['joints'][:, :-1, :]
                    motion_tensor = self.dict_to_tensor(features).detach().cpu()
                    if self.padding:
                        # Last mask entry flags padding anywhere in the future window.
                        mask_slice = primitive_dict[person]['primitive_padding_mask'][batch_start_idx:batch_end_idx, -1]
                        valid_indices = torch.nonzero(~mask_slice, as_tuple=True)[0].detach().cpu()
                        motion_tensor = motion_tensor[valid_indices]
                    all_mp_data.append(motion_tensor)

        if len(all_mp_data) == 0:
            raise RuntimeError('calc_mean_std: no motion primitives could be extracted from the dataset')
        all_mp_data = torch.cat(all_mp_data, dim=0)
        tensor_mean = all_mp_data.mean(dim=[0, 1], keepdim=True)
        tensor_std = all_mp_data.std(dim=[0, 1], keepdim=True)
        return tensor_mean.to(self.device), tensor_std.to(self.device)

    def get_primitive(self, seq_data, start_frame, end_frame, skip_text=False):
        """Slice one two-person motion primitive out of a sequence.

        Frames ``start_frame`` .. ``end_frame`` are taken, ``end_frame`` included.

        Args:
            seq_data: sequence dict with per-person motion dicts under 'motion_p1'
                and 'motion_p2', and optionally 'frame_labels_<key>' text segments.
            start_frame: first frame of the primitive.
            end_frame: last frame of the primitive (inclusive).
            skip_text: when True, no text labels are looked up.

        Returns:
            dict containing:

            - 'primitive_dict': per-person tensors with a leading batch dim of 1;
              plus 'primitive_padding_mask' of shape ``(1, history_length + 1)``
              when ``self.padding`` is set. That mask holds the per-frame history
              padding followed by one flag for "any padding in the future window".
            - one 'text_<key>' string per entry of ``self.text_key_list``: a random
              label overlapping the future window, or '' if none.
        """
        primitive_dict = {}
        for person, motion_data in zip(['person1', 'person2'], [seq_data['motion_p1'], seq_data['motion_p2']]):
            primitive_dict[person] = {
                'gender': motion_data['gender'],
                'betas': motion_data['betas'].expand(1, self.primitive_length + 1, 10),
                'transl': motion_data['transl'][start_frame:end_frame + 1].unsqueeze(0),
                'global_orient': motion_data['global_orient'][start_frame:end_frame + 1].unsqueeze(0),
                'body_pose': motion_data['body_pose'][start_frame:end_frame + 1].unsqueeze(0),
                'pelvis_delta': motion_data['pelvis_delta'].unsqueeze(0),
                'joints': motion_data['joints'][start_frame:end_frame + 1].unsqueeze(0),
                'transf_rotmat': torch.eye(3).unsqueeze(0),
                'transf_transl': torch.zeros(1, 1, 3),
            }
            if self.padding:
                padding_mask_full = motion_data['padding_mask'][start_frame:end_frame + 1]
                # as_tensor accepts numpy arrays and tensors alike; torch.tensor() on an
                # existing tensor warns and copies.
                history_mask = torch.as_tensor(padding_mask_full[:self.history_length], dtype=torch.bool)
                # The final sliced frame is not part of the future window.
                future_mask = padding_mask_full[self.history_length:-1]
                future_flag = torch.tensor([bool(future_mask.any())], dtype=torch.bool)
                primitive_dict[person]['primitive_padding_mask'] = torch.cat([history_mask, future_flag], dim=0).unsqueeze(0)

        texts = {key: [] for key in self.text_key_list}
        for key_type in self.text_key_list:
            if not skip_text and f'frame_labels_{key_type}' in seq_data:
                # Future window in seconds, widened by text_tolerance on both sides.
                future_start = (start_frame + self.history_length) / self.target_fps
                future_end = (start_frame + self.history_length + self.future_length - 1) / self.target_fps
                for seg in seq_data[f'frame_labels_{key_type}']:
                    if have_overlap([seg['start_t'], seg['end_t']], [future_start - self.text_tolerance, future_end + self.text_tolerance]):
                        texts[key_type].append(seg['proc_label'])

        output = {
            'text_' + key_type: random.choice(texts[key_type]) if len(texts[key_type]) > 0 else ''
            for key_type in self.text_key_list
        }
        output['primitive_dict'] = primitive_dict
        return output

    def get_batch(self, batch_size=8):
        """Sample a batch of two-person motion-primitive sequences.

        Sampling and preprocessing:

        - Sequences are chosen by ``get_batch_idx``.
        - For each sequence, the persons are swapped with probability 0.5.
        - A window of ``seq_length`` frames is picked (weighted by
          ``frame_weights`` when ``weight_scheme`` contains 'text') and cut into
          primitives with ``get_primitive``.
        - Sequences are grouped by person1's gender (female, male, neutral).
        - Within each group, person1 is canonicalised, person2 is expressed in
          person1's canonical frame, and both are converted to normalised
          features.
        - The groups are finally concatenated along the batch dimension.

        Args:
            batch_size: number of sequences to sample.

        Returns:
            ``mode == 'merged'``: dict with keys 'person1', 'person2' and
            'interaction'. Each maps to a list of ``num_primitive`` per-primitive
            dicts; the 'interaction' entries carry only text fields.

            ``mode == 'sep'``: list of ``num_primitive`` per-primitive dicts. Each
            row is a single person, chosen by one random selector over the pooled
            person1/person2 rows of a gender group; the same selection is used for
            every primitive.

            ``None`` if no sampled sequence matches one of the three genders.

        Raises:
            ValueError: if ``self.mode`` is neither 'merged' nor 'sep'.
        """
        if self.mode not in ('merged', 'sep'):
            raise ValueError(f"unsupported mode {self.mode!r}; expected 'merged' or 'sep'")
        self.time = time.time()
        person_keys = ['person1', 'person2']
        primitive_tensor_keys = ['betas', 'transl', 'global_orient', 'body_pose', 'transf_rotmat', 'transf_transl',
                                 'pelvis_delta', 'joints', 'start_frame', 'total_frames']
        seq_list = []
        batch_idx = self.get_batch_idx(batch_size)
        # Output keys merged by list concatenation (add) vs. torch.cat along dim 0 (cat).
        add_key_list = ['gender']
        cat_key_list = ['betas', 'motion_tensor_normalized', 'history_motion', 'history_mask', 'transf_rotmat', 'transf_transl', 'start_frame', 'total_frames']
        if self.padding:
            cat_key_list.append('primitive_padding_mask')
        if self.use_indi_text:
            add_key_list.append('texts')
            if self.load_text_embedding:
                cat_key_list.append('text_embedding')
                # 'text_mask' is only produced together with the embeddings; listing it
                # without load_text_embedding made the merging below raise KeyError.
                if self.text_sep:
                    cat_key_list.append('text_mask')

        # 1) Sample sequences and cut each one into primitives.
        for seq_idx in batch_idx:
            seq_data = dict(self.dataset[seq_idx])  # shallow copy: the swap must not touch the dataset
            # exchange person1 and person2
            if random.random() < 0.5:
                seq_data['motion_p1'], seq_data['motion_p2'] = seq_data['motion_p2'], seq_data['motion_p1']
                if self.use_indi_text:
                    seq_data['frame_labels_person1'], seq_data['frame_labels_person2'] = seq_data['frame_labels_person2'], seq_data['frame_labels_person1']
            num_frames = len(seq_data['motion_p1']['transl'])
            if 'text' in self.weight_scheme:
                start_frame = random.choices(range(num_frames - self.seq_length + 1), weights=seq_data['frame_weights'], k=1)[0]
            else:
                start_frame = random.randint(0, num_frames - self.seq_length)  # [0, num_frames - seq_length], right end inclusive
            primitive_data_list = []
            for frame_idx in range(start_frame, start_frame + self.seq_length - self.primitive_length, self.future_length):
                primitive_data = self.get_primitive(seq_data, frame_idx, frame_idx + self.primitive_length)
                for person in person_keys:
                    primitive_data['primitive_dict'][person]['start_frame'] = torch.tensor(frame_idx).view(1)
                    primitive_data['primitive_dict'][person]['total_frames'] = torch.tensor(num_frames).view(1)
                primitive_data_list.append(primitive_data)
            seq_list.append(primitive_data_list)

        # 2) Build one sub-batch per gender group (grouped by person1's gender) and concatenate.
        batch = None
        for gender in ['female', 'male', 'neutral']:
            gender_seq_list = [mp_seq for mp_seq in seq_list if mp_seq[0]['primitive_dict']['person1']['gender'] == gender]
            if len(gender_seq_list) == 0:
                continue
            gender_batch_size = len(gender_seq_list)

            # Flatten primitive-major: rows [p*B, (p+1)*B) hold primitive p of every sequence.
            flat_primitives = [mp_seq[primitive_idx] for primitive_idx in range(self.num_primitive) for mp_seq in gender_seq_list]
            gender_seq_texts = {key_type: [mp['text_' + key_type] for mp in flat_primitives] for key_type in self.text_key_list}
            gender_seq_dict = {}
            for person in person_keys:
                # NOTE(review): person2 is labelled with person1's gender here, so
                # mixed-gender pairs are not distinguished.
                person_dict = {'gender': gender}
                for key in primitive_tensor_keys:
                    person_dict[key] = torch.cat([mp['primitive_dict'][person][key] for mp in flat_primitives], dim=0)
                if self.padding:
                    person_dict['primitive_padding_mask'] = torch.cat([mp['primitive_dict'][person]['primitive_padding_mask'] for mp in flat_primitives], dim=0)
                gender_seq_dict[person] = tensor_dict_to_device(person_dict, self.device)

            # Canonicalise person1; express person2 in person1's canonical frame.
            canonicalized_primitive_dict = {}
            transf_rotmat, transf_transl = {}, {}
            transf_rotmat['person1'], transf_transl['person1'], canonicalized_primitive_dict['person1'] = self.primitive_utility.canonicalize(copy.deepcopy(gender_seq_dict['person1']), use_predicted_joints=True)
            canonicalized_primitive_dict['person2'] = self.primitive_utility.relative_canonicalize(copy.deepcopy(gender_seq_dict['person2']), transf_rotmat['person1'], transf_transl['person1'])
            transf_rotmat['person2'], transf_transl['person2'] = copy.deepcopy(transf_rotmat['person1']), copy.deepcopy(transf_transl['person1'])

            # calc features
            feature_dict = {}
            for person in person_keys:
                feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
                if self.primitive_utility.feature_dim == 276:
                    feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]
                    feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]
                    feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]

            # Per-person list of per-primitive dicts (shared by both modes).
            person_batches = {}
            for person in person_keys:
                motion_tensor_normalized = self.normalize(self.dict_to_tensor(feature_dict[person]))
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)  # [B*num_mp, D, 1, T]
                history_mask = torch.zeros_like(motion_tensor_normalized, dtype=torch.bool, device=self.device)
                history_mask[..., :self.cfg.history_length] = True
                history_motion = torch.zeros_like(motion_tensor_normalized, dtype=torch.float32, device=self.device)
                history_motion[..., :self.cfg.history_length] = motion_tensor_normalized[..., :self.cfg.history_length]

                person_batches[person] = []
                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = start_idx + gender_batch_size
                    primitive_batch = {
                        'gender': [gender_seq_dict[person]['gender']] * gender_batch_size,
                        'betas': gender_seq_dict[person]['betas'][start_idx:end_idx, :-1, :10],
                        'motion_tensor_normalized': motion_tensor_normalized[start_idx:end_idx, ...],  # [B, D, 1, T]
                        'history_motion': history_motion[start_idx:end_idx, ...],
                        'history_mask': history_mask[start_idx:end_idx, ...],
                        'transf_rotmat': transf_rotmat[person][start_idx:end_idx, ...],
                        'transf_transl': transf_transl[person][start_idx:end_idx, ...],
                        'start_frame': gender_seq_dict[person]['start_frame'][start_idx:end_idx],
                        'total_frames': gender_seq_dict[person]['total_frames'][start_idx:end_idx],
                        'history_length': self.cfg.history_length,
                        'future_length': self.cfg.future_length,
                    }
                    if self.use_indi_text:
                        primitive_texts = gender_seq_texts[person][start_idx:end_idx]
                        primitive_batch['texts'] = primitive_texts
                        if self.load_text_embedding:
                            primitive_batch['text_embedding'], primitive_batch['text_mask'] = self._lookup_text_embeddings(person, primitive_texts)
                    if self.padding:
                        primitive_batch['primitive_padding_mask'] = gender_seq_dict[person]['primitive_padding_mask'][start_idx:end_idx, ...]
                    person_batches[person].append(primitive_batch)

            if self.mode == 'merged':
                gender_batch = {'person1': person_batches['person1'], 'person2': person_batches['person2'], 'interaction': []}
                for primitive_idx in range(self.num_primitive):
                    start_idx = primitive_idx * gender_batch_size
                    end_idx = start_idx + gender_batch_size
                    primitive_texts = gender_seq_texts['interaction'][start_idx:end_idx]
                    interaction_batch = {'texts': primitive_texts}
                    if self.load_text_embedding:
                        interaction_batch['text_embedding'], interaction_batch['text_mask'] = self._lookup_text_embeddings('interaction', primitive_texts)
                    gender_batch['interaction'].append(interaction_batch)
            else:  # 'sep'
                # Pool person1 rows (front) and person2 rows (back); a shuffled selector
                # with gender_batch_size ones chooses which rows survive.
                front_primitives = person_batches['person1']
                back_primitives = person_batches['person2']
                selector = torch.cat([torch.ones(gender_batch_size), torch.zeros(gender_batch_size)])
                selector = selector[torch.randperm(2 * gender_batch_size)]

                front_group, back_group = {}, {}
                for key in add_key_list:
                    front_group[key] = [item for d in front_primitives for item in d[key]]
                    back_group[key] = [item for d in back_primitives for item in d[key]]
                for key in cat_key_list:
                    front_group[key] = torch.cat([d[key] for d in front_primitives], dim=0)
                    back_group[key] = torch.cat([d[key] for d in back_primitives], dim=0)

                front_indices = torch.nonzero(selector[:gender_batch_size], as_tuple=True)[0]
                back_indices = torch.nonzero(selector[gender_batch_size:], as_tuple=True)[0]

                gender_batch = []
                for primitive_idx in range(self.num_primitive):
                    offset = primitive_idx * gender_batch_size
                    front_rows = front_indices + offset
                    back_rows = back_indices + offset
                    selected_dict = {'history_length': self.cfg.history_length,
                                     'future_length': self.cfg.future_length}
                    for key in add_key_list:
                        selected_dict[key] = ([front_group[key][row] for row in front_rows.tolist()]
                                              + [back_group[key][row] for row in back_rows.tolist()])
                    for key in cat_key_list:
                        selected_dict[key] = torch.cat([front_group[key][front_rows], back_group[key][back_rows]], dim=0)
                    gender_batch.append(selected_dict)

            if batch is None:
                batch = gender_batch
            else:  # concatenate different gender batch
                for primitive_idx in range(self.num_primitive):
                    if self.mode == 'merged':
                        # Concatenate every entry the batch actually holds (person1, person2, interaction).
                        for key_type in batch:
                            merged_entry = batch[key_type][primitive_idx]
                            new_entry = gender_batch[key_type][primitive_idx]
                            if key_type != 'interaction':
                                for key in add_key_list:
                                    merged_entry[key] = merged_entry[key] + new_entry[key]
                                for key in cat_key_list:
                                    merged_entry[key] = torch.cat([merged_entry[key], new_entry[key]], dim=0)
                            else:
                                merged_entry['texts'] = merged_entry['texts'] + new_entry['texts']
                                if self.load_text_embedding:
                                    merged_entry['text_embedding'] = torch.cat([merged_entry['text_embedding'], new_entry['text_embedding']], dim=0)
                                    # text_mask is None unless text_sep; torch.cat would fail on None.
                                    if self.text_sep:
                                        merged_entry['text_mask'] = torch.cat([merged_entry['text_mask'], new_entry['text_mask']], dim=0)
                    else:
                        for key in add_key_list:
                            batch[primitive_idx][key] = batch[primitive_idx][key] + gender_batch[primitive_idx][key]
                        for key in cat_key_list:
                            batch[primitive_idx][key] = torch.cat([batch[primitive_idx][key], gender_batch[primitive_idx][key]], dim=0)

        return batch

    def _lookup_text_embeddings(self, key_type, texts):
        """Return ``(text_embedding, text_mask)`` for ``texts`` from the ``key_type`` cache.

        Texts missing from ``self.text_embedding_dict[key_type]`` are encoded and
        cached first. ``text_embedding`` stacks the cached embeddings along a new
        dim 0. ``text_mask`` is the matching stack from ``self.text_mask_dict``
        when ``self.text_sep`` is set, else None.
        """
        unseen_texts = [text for text in texts if text not in self.text_embedding_dict[key_type]]
        if len(unseen_texts) > 0:
            self.update_text_embedding_dict(unseen_texts, key_type, text_sep=self.text_sep, max_segs=self.max_segs)
        text_embedding = torch.stack([self.text_embedding_dict[key_type][text] for text in texts], dim=0)
        if self.text_sep:
            text_mask = torch.stack([self.text_mask_dict[key_type][text] for text in texts], dim=0)
        else:
            text_mask = None
        return text_embedding, text_mask

# dataset = InterGenDatasetWPE(enforce_gender=None,
#                             enforce_zero_beta=0,
#                             device='cuda:7',
#                             mode='merged',
#                             text_encoder='clip',
#                             text_sep=False,
#                             split='train',
#                             use_interaction_model=False,
#                             padding=True,
#                             clip_version='ViT-L/14@336px',
#                             load_text_embedding=False,
#                             use_indi_text=False,)
# # dataset.calc_mean_std()
# batch = dataset.get_batch(4)
# # dataset.get_item(0)
