import time
import copy
import numpy as np
import os
from os.path import join as pjoin

import torch
import torch.nn as nn
import random
import json
import torch.multiprocessing as mp
from pytorch3d import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data._utils.collate import default_collate

from utils.smpl_utils import *
from utils.word_vectorizer import WordVectorizer, POS_enumerator

# get_data_interx.py
import torch
from torch.utils.data._utils.collate import default_collate

def collate_fn(batch):
    """Collate samples after reordering them by descending ``sample[3]``.

    The list passed in is sorted in place (largest element-3 value first;
    presumably the caption length for this project's datasets — confirm),
    and the reordered list is then handed to ``default_collate``.
    """
    def _descending_key(sample):
        return sample[3]

    batch.sort(reverse=True, key=_descending_key)
    collated = default_collate(batch)
    return collated

def extract_replication(data, indices):
    """Select replicate(s) from every per-person field of ``data``.

    Only the ``'person1'`` and ``'person2'`` entries of ``data`` are read and
    returned; other top-level keys are not copied.

    Args:
        data: Mapping with ``'person1'`` and ``'person2'`` sub-dicts whose
            values are lists or ``torch.Tensor`` objects.
        indices: A single integer index (``int`` or any ``numbers.Integral``
            such as ``numpy.int64``) or an iterable of integer indices
            (list, tuple, ``range``, generator, ...).

    Returns:
        ``{'person1': {...}, 'person2': {...}}``. A list value becomes the list
        of selected elements. A tensor value is indexed with the list of
        indices, so its leading axis is kept (size ``len(indices)``), even for
        a single index.

    Raises:
        TypeError: if a value is neither a list nor a tensor, or if an index
            is not an integer.
    """
    import numbers
    import operator

    # Materialize the indices exactly once. The same sequence is reused for
    # every key below, so a one-shot iterator (e.g. a generator) must not be
    # consumed by the first field.
    if isinstance(indices, numbers.Integral):
        index_list = [operator.index(indices)]
    else:
        index_list = [operator.index(i) for i in indices]

    extracted = {}
    for person in ('person1', 'person2'):
        selected = {}
        for key, value in data[person].items():
            if isinstance(value, list):
                selected[key] = [value[i] for i in index_list]
            elif isinstance(value, torch.Tensor):
                # Indexing with a Python list keeps the leading axis.
                selected[key] = value[index_list]
            else:
                raise TypeError(f"Unsupported data type for key '{key}': {type(value)}")
        extracted[person] = selected
    return extracted


class EvaluationDataset(Dataset):
    def __init__(self,
                 args,
                 denoiser_model,
                 denoiser_args,
                 diffusion,
                 diffusion_args,
                 vae_model,
                 vae_args,
                 data_args,
                 ground_truth_dataset,
                 device,
                 mm_num_samples,
                 mm_num_repeats,
                 load_data=True,
                 load_dir='',
                 w_vectorizer=None,
                 **kwargs,):
        """Generate (or load) one motion per ground-truth sample for evaluation.

        The items of ``ground_truth_dataset`` are iterated one at a time, in
        order (``batch_size=1``, ``shuffle=False``). For each item a motion is
        either produced with ``generate_motion`` or loaded from
        ``load_dir/motion_<seq_name>.pt`` (when ``args.load_from_file``).

        A subset of dataset indices is selected for multimodality evaluation.
        It is either ``mm_num_samples`` randomly shuffled indices or the
        indices read from ``load_dir/mm_indices.json``. Those samples get
        ``mm_num_repeats`` replicates.

        When ``load_data`` is true, two attributes are set:

        - ``self.generated_motion``: one dict per sample, holding only the
          first replicate.
        - ``self.mm_generated_motions``: one dict per multimodality sample,
          holding all replicates.

        Recognized ``kwargs``:

        - ``replication_times`` (default 1)
        - ``replication`` (default 0)
        - ``text_encoder`` (default ``None``)
        """
        self.args = args
        self.denoiser_model = denoiser_model
        self.denoiser_model.eval()
        self.denoiser_args = denoiser_args
        self.diffusion = diffusion
        self.diffusion_args = diffusion_args
        self.vae_model = vae_model
        self.vae_args = vae_args
        self.data_args = data_args
        self.dataset = ground_truth_dataset
        dataloader = DataLoader(ground_truth_dataset, batch_size=1, num_workers=0, shuffle=False)
        self.device = device
        self.mm_num_samples = mm_num_samples
        self.mm_num_repeats = mm_num_repeats
        self.replication_times = kwargs.get('replication_times', 1)
        self.replication = kwargs.get('replication', 0)
        # Own PrimitiveUtility on self.device, configured like the dataset's one.
        self.primitive_utility = PrimitiveUtility(device=self.device,
                                                  body_type=self.dataset.primitive_utility.body_type,
                                                  motion_repr=self.dataset.primitive_utility.motion_repr)
        print('body_type:', self.primitive_utility.body_type)

        self.text_encoder = kwargs.get('text_encoder', None)
        self.text_encoder_version = getattr(denoiser_args, 'text_encoder_version', 'v1')

        self.w_vectorizer = w_vectorizer
        self.max_motion_length = self.dataset.max_motion_length

        if not load_data:
            return

        # Dataset indices used for multimodality evaluation, kept as a set for
        # O(1) membership tests in the loop below.
        if args.load_from_file:
            with open(os.path.join(load_dir, "mm_indices.json")) as f:
                mm_idxs = set(json.load(f))
        else:
            idxs = list(range(len(self.dataset)))
            random.shuffle(idxs)
            mm_idxs = set(idxs[:mm_num_samples])

        generated_motion = []
        mm_generated_motions = []

        with torch.no_grad():
            for i, data in enumerate(dataloader):
                _word_emb, _pos_ohot, caption, cap_lens, motions, motion_lens, tokens = data
                tokens = tokens[0].split('_')
                seq_name = motions['seq_name'][0]
                # Membership is keyed by the dataset index ``i``, matching how
                # mm_idxs is built (sampled indices or mm_indices.json).
                is_mm = i in mm_idxs

                if args.load_from_file:
                    path = os.path.join(load_dir, f"motion_{seq_name}.pt")
                    motion_sequences = torch.load(path)
                    if self.replication_times > 1:
                        # Keep only the replicates belonging to self.replication.
                        if is_mm:
                            start = self.replication * self.mm_num_repeats
                            motion_sequences = extract_replication(
                                motion_sequences, range(start, start + self.mm_num_repeats))
                        else:
                            motion_sequences = extract_replication(motion_sequences, self.replication)
                else:
                    num_repeats = mm_num_repeats if is_mm else 1
                    motion_sequences = self.generate_motion(motions, caption, motion_lens, num_repeats)

                if self.data_args.interaction:
                    first_batch_motion = {
                        person: self._first_sample(motion_sequences[person])
                        for person in ('person1', 'person2')
                    }
                else:
                    first_batch_motion = self._first_sample(motion_sequences)

                generated_motion.append({
                    'motion': first_batch_motion,
                    'length': motion_lens[0].item(),
                    'cap_len': cap_lens[0].item(),
                    'caption': caption[0],
                    'tokens': tokens,
                })

                if is_mm:
                    mm_generated_motions.append({
                        'caption': caption[0],
                        'tokens': tokens,
                        'cap_len': cap_lens[0].item(),
                        'length': [motion_lens[0].item()] * self.mm_num_repeats,
                        'mm_motions': motion_sequences,
                    })

        self.generated_motion = generated_motion
        self.mm_generated_motions = mm_generated_motions

    @staticmethod
    def _first_sample(motion_dict):
        """Return a new dict keeping only the first batch entry of each value.

        Tensors are indexed with ``[0]`` (dropping the leading axis), lists are
        reduced to a one-element list, and any other value is passed through.
        """
        first = {}
        for key, value in motion_dict.items():
            if isinstance(value, torch.Tensor):
                first[key] = value[0]
            elif isinstance(value, list):
                first[key] = [value[0]]
            else:
                first[key] = value
        return first

    def preprocess(self, data, caption):
        """Canonicalize a two-person ground-truth sample into denoiser inputs.

        Steps:

        - Each person's motion is canonicalized with
          ``primitive_utility.canonicalize`` (on a deep copy).
        - Relative poses are computed with
          ``compute_rel_transform_B_in_A`` and encoded as
          ``[6D rotation | translation]``. ``'b2a'`` uses person1 as A and
          person2 as B; ``'a2b'`` swaps them.
        - Per-person features are computed, trimmed by one frame for the known
          feature layouts, normalized, and reshaped from ``[B, T, D]`` to
          ``[B, D, 1, T]``.

        Args:
            data: Collated sample holding ``'person1'``/``'person2'`` dicts
                (each with at least ``'gender'`` and ``'betas'``). When
                ``dataset.use_indi_text`` is set, it also holds
                ``'frame_labels_person1'``/``'frame_labels_person2'`` lists of
                dicts with a ``'proc_label'`` entry.
            caption: Sequence whose first element is the interaction text.

        Returns:
            Dict with ``'person1'``, ``'person2'`` and ``'interaction'`` entries.

        Raises:
            NotImplementedError: if ``dataset.use_indi_text`` and
                ``dataset.load_text_embedding`` are both enabled; no
                per-person text embedding lookup is implemented here.
        """
        dataset = self.dataset
        utility = self.primitive_utility
        persons = ('person1', 'person2')

        if dataset.use_indi_text and dataset.load_text_embedding:
            # Only the interaction text embedding is ever looked up in this
            # method, so per-person embeddings cannot be provided. Fail fast
            # with a clear message instead of reading an undefined name.
            raise NotImplementedError(
                "EvaluationDataset.preprocess: per-person text embeddings "
                "(use_indi_text with load_text_embedding) are not supported"
            )

        # Canonicalization (deep copy so the caller's data is not passed in directly).
        transf_rotmat, transf_transl, canonical = {}, {}, {}
        for person in persons:
            transf_rotmat[person], transf_transl[person], canonical[person] = utility.canonicalize(
                copy.deepcopy(data[person]), use_predicted_joints=True)

        # Relative pose of B in A's canonical frame, encoded as [6D rot | transl].
        rel_pose = {}
        for rel, (frame_a, frame_b) in (('b2a', ('person1', 'person2')),
                                        ('a2b', ('person2', 'person1'))):
            rel_rotmat, rel_transl = utility.compute_rel_transform_B_in_A(
                transf_rotmat[frame_a], transf_transl[frame_a],
                transf_rotmat[frame_b], transf_transl[frame_b])
            rel_pose[rel] = torch.cat(
                [transforms.matrix_to_rotation_6d(rel_rotmat), rel_transl.squeeze(1)], dim=-1)  # [1, 6+3]

        # Feature keys whose final frame is dropped for each known feature
        # layout (presumably calc_features yields one extra frame — TODO confirm).
        feature_dim = utility.feature_dim
        if feature_dim == 276:
            trim_keys = ('transl', 'poses_6d', 'joints')
        elif feature_dim == 56 * 6:
            trim_keys = ('transl', 'body_pose')
        elif feature_dim == 55 * 12:
            trim_keys = ('joints', 'body_pose')
        else:
            trim_keys = ()

        data_batch = {}
        for person in persons:
            features = utility.calc_features(canonical[person], use_predicted_joints=True)
            for key in trim_keys:
                features[key] = features[key][:, :-1, :]
            motion = dataset.normalize(utility.dict_to_tensor(features))  # [1, T, D]
            motion = motion.permute(0, 2, 1).unsqueeze(2)  # [1, D, 1, T]

            entry = {
                'gender': data[person]['gender'],
                'betas': data[person]['betas'],
                'motion_tensor_normalized': motion,
                'transf_rotmat': transf_rotmat[person],
                'transf_transl': transf_transl[person],
                'history_length': dataset.history_length,
                'future_length': dataset.future_length,
            }
            if dataset.use_indi_text:
                labels = [label['proc_label'] for label in data[f'frame_labels_{person}']]
                entry['texts'] = random.choice(labels)
            data_batch[person] = entry

        texts = [caption[0]]
        interaction = {
            'texts': texts,
            'rel_pose_b2a': rel_pose['b2a'],
            'rel_pose_a2b': rel_pose['a2b'],
        }
        if dataset.load_text_embedding:
            interaction['text_embedding'] = torch.stack(
                [dataset.text_embedding_dict['interaction'][text] for text in texts], dim=0).detach().cpu()
            if dataset.text_sep:
                interaction['text_mask'] = torch.stack(
                    [dataset.text_mask_dict['interaction'][text] for text in texts], dim=0).detach().cpu()
        data_batch['interaction'] = interaction

        return data_batch
    
    def generate_motion(self, batch_data, interaction_text, motion_lens, batch_size):
        """
        Generate motion for a given batch of data.
        :param batch_data: A dictionary containing the input data for the model.
        :return: A dictionary containing the generated motion and other relevant information.
        """
        device = self.device
        future_length = self.dataset.future_length
        history_length = self.dataset.history_length
        primitive_length = history_length + future_length
        sample_fn = self.diffusion.p_sample_loop if self.diffusion_args.respacing == '' else self.diffusion.ddim_sample_loop
        
        batch_data = self.preprocess(batch_data, interaction_text)

        lengths = motion_lens
        
        def expand_data(data, batch_size):
            for key in data.keys():
                if key in ['texts', 'gender']:
                    data[key] *= batch_size
                elif key in ['text_embedding', 'text_mask', 'betas', 'motion_tensor_normalized', 'transf_rotmat', 'transf_transl', 'rel_pose_b2a', 'rel_pose_a2b']:
                    data[key] = data[key].repeat(batch_size, *([1] * (data[key].dim() - 1)))
                else:
                    continue
        
        if self.data_args.interaction:
            for person in ['person1', 'person2', 'interaction']:
                expand_data(batch_data[person], batch_size)
        else:
            expand_data(batch_data, batch_size)
        
        if self.data_args.interaction:
            input_motions, model_kwargs, gender, betas, pelvis_delta, motion_tensor, history_motion_gt = {}, {}, {}, {}, {}, {}, {}
            for person in ['person1', 'person2']:
                input_motions[person] = batch_data[person]['motion_tensor_normalized']
                model_kwargs[person] = {'y': batch_data[person]}
                del model_kwargs[person]['y']['motion_tensor_normalized']
                gender[person] = model_kwargs[person]['y']['gender']
                betas[person] = model_kwargs[person]['y']['betas'][:, :primitive_length, :10].to(device) # [B, H+F, 10]
                pelvis_delta[person] = self.primitive_utility.calc_calibrate_offset({
                    'betas': betas[person][:, 0, :],
                    'gender': gender[person],
                })
                input_motions[person] = input_motions[person].to(device)  # [B, D, 1, T]
                motion_tensor[person] = input_motions[person].squeeze(2).permute(0, 2, 1)  # [B, T, D]
                history_motion_gt[person] = motion_tensor[person][:, :history_length, :]  # [B, H, D]
            interaction = batch_data['interaction']
        else:
            input_motions, model_kwargs = batch_data['motion_tensor_normalized'], {'y': batch_data}
            del model_kwargs['y']['motion_tensor_normalized']
            gender = model_kwargs['y']['gender']
            betas = model_kwargs['y']['betas'][:, :primitive_length, :].to(device)  # [B, H+F, 10]
            pelvis_delta = self.primitive_utility.calc_calibrate_offset({
                'betas': betas[:, 0, :],
                'gender': gender,
            })
            input_motions = input_motions.to(device)  # [B, D, 1, T]
            motion_tensor = input_motions.squeeze(2).permute(0, 2, 1)  # [B, T, D]
            history_motion_gt = motion_tensor[:, :history_length, :]  # [B, H, D]

        motion_sequences = None if not self.data_args.interaction else {'person1': None, 'person2': None}
        history_motion = history_motion_gt
        transf_rotmat = {
            'person1': batch_data['person1']['transf_rotmat'].to(device),
            'person2': batch_data['person2']['transf_rotmat'].to(device),
        } if self.data_args.interaction else batch_data['transf_rotmat'].to(device)
        transf_transl = {
            'person1': batch_data['person1']['transf_transl'].to(device),
            'person2': batch_data['person2']['transf_transl'].to(device),
        } if self.data_args.interaction else batch_data['transf_transl'].to(device)
        
        if self.args.fix_floor:
            if self.data_args.interaction:
                motion_dict = {}
                for person in ['person1', 'person2']:
                    motion_dict[person] = self.primitive_utility.tensor_to_dict(self.dataset.denormalize(history_motion_gt[person]))
                    joints = motion_dict[person]['joints'].reshape(batch_size, history_length, 22, 3)  # [B, T, 22, 3]   
                    init_floor_height = joints[:, 0, :, 2].amin(dim=-1)  # [B]
                    transf_transl[person][:, :, 2] = -init_floor_height.unsqueeze(-1)
            else:
                motion_dict = self.primitive_utility.tensor_to_dict(self.dataset.denormalize(history_motion_gt))
                joints = motion_dict['joints'].reshape(batch_size, history_length, 22, 3)  # [B, T, 22, 3]
                init_floor_height = joints[:, 0, :, 2].amin(dim=-1)  # [B]
                transf_transl[:, :, 2] = -init_floor_height.unsqueeze(-1)

        num_primitives = int(np.ceil((lengths-history_length) / future_length))
        guidance_param = torch.ones(batch_size, *self.denoiser_args.model_args.noise_shape).to(device=self.device) * self.args.guidance_param
        if self.denoiser_args.use_pre_latent:
            pre_latent = [] if not self.data_args.interaction else {'person1': [], 'person2': []}
            pre_transf_rotmat_abs = [] if not self.data_args.interaction else {'person1': [], 'person2': []}
            pre_transf_transl_abs = [] if not self.data_args.interaction else {'person1': [], 'person2': []}

        for primitive_id in range(num_primitives):
            valid_length = min(future_length, lengths - (primitive_id * future_length+history_length))
            if self.data_args.interaction:
                if self.denoiser_args.merge_his_relpose:
                    history_motion_rel = {}
                    for person in ['person1', 'person2']:
                        history_motion_rel[person] = copy.deepcopy(history_motion[person])
                        his_motion_rel_denormalized = self.dataset.denormalize(history_motion_rel[person])
                        his_motion_rel = self.primitive_utility.relative_transform_feature_tensor(
                            his_motion_rel_denormalized.to(device),
                            transforms.rotation_6d_to_matrix(interaction['rel_pose_a2b' if person == 'person1' else 'rel_pose_b2a'][:, :6]).to(device),
                            interaction['rel_pose_a2b' if person == 'person1' else 'rel_pose_b2a'][:, 6:9].unsqueeze(1).to(device),
                            batch_data[person]['gender'],
                            batch_data[person]['betas'][:, 0].to(device),
                        )
                        history_motion_rel[person] = self.dataset.normalize(his_motion_rel)  # [B, H, D]

                latent_pred, future_motion_pred, future_frames, all_frames, valid_future_frames, new_history_frames = {}, {}, {}, {}, {}, {}
                for person in ['person1', 'person2']:
                    y = {
                        'history_motion_normalized': history_motion[person],
                        'history_motion_normalized_b': history_motion['person2' if person == 'person1' else 'person1'],
                        'text_inter': batch_data['interaction']['texts'],
                        'rel_pose': interaction['rel_pose_'+'b2a' if person == 'person1' else 'rel_pose_'+'a2b'].to(device),
                        'scale': guidance_param,
                    }
                    if self.denoiser_args.load_text_embedding:
                        if self.denoiser_args.use_indi_text:
                            y['text_embedding'] = batch_data[person]['text_embedding'].to(device)
                        y['text_embedding_inter'] = batch_data['interaction']['text_embedding'].to(device)
                    else:
                        if self.text_encoder_version == 'v1':
                            if self.denoiser_args.use_indi_text:
                                y['text_embedding'] = self.text_encoder(batch_data[person]['texts']).to(device)
                            y['text_embedding_inter'] = self.text_encoder(batch_data['interaction']['texts']).to(device)
                        elif self.text_encoder_version in ['v2', 'v3']:
                            if self.denoiser_args.use_indi_text:
                                y['text_embedding'], batch_data[person]['text_mask'] = self.text_encoder(batch_data[person]['texts'])
                                y['text_embedding'] = y['text_embedding'].to(device)
                            y['text_embedding_inter'], batch_data['interaction']['text_mask'] = self.text_encoder(batch_data['interaction']['texts'])
                            y['text_embedding_inter'] = y['text_embedding_inter'].to(device)
                    if self.denoiser_args.merge_his_relpose:
                        y['history_motion_normalized_b'] = history_motion_rel['person2' if person == 'person1' else 'person1']
                    if self.denoiser_args.text_sep:
                        if self.denoiser_args.use_indi_text:
                            y['text_mask'] = batch_data[person]['text_mask'].to(device)
                        y['text_mask_inter'] = batch_data['interaction']['text_mask'].to(device)
                    if self.denoiser_args.use_pre_latent:
                        y['pre_latent'] = torch.cat(pre_latent[person], dim=1) if len(pre_latent[person])!=0 else None
                        if len(pre_transf_rotmat_abs[person])==0:
                            y['pre_reltrans'] = None
                        else:
                            y['pre_reltrans'] = []
                            for transf_rotmat_abs, transf_transl_abs in zip(pre_transf_rotmat_abs[person], pre_transf_transl_abs[person]):
                                rel_rotmat, rel_transl = self.primitive_utility.compute_rel_transform_B_in_A(
                                    transf_rotmat[person], transf_transl[person], transf_rotmat_abs, transf_transl_abs)
                                y['pre_reltrans'].append(torch.cat([transforms.matrix_to_rotation_6d(rel_rotmat), rel_transl.squeeze(1)], dim=-1).unsqueeze(1))
                            y['pre_reltrans'] = torch.cat(y['pre_reltrans'], dim=1)  # [B, num_primitive, 6+3]

                    x_start_pred = sample_fn(
                        self.denoiser_model,
                        (batch_size, *self.denoiser_args.model_args.noise_shape),
                        clip_denoised=False,
                        model_kwargs={'y': y},
                        skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                        init_image=None,
                        progress=False,
                        dump_steps=None,
                        noise=torch.zeros_like(guidance_param) if self.args.zero_noise else None,
                        const_noise=False,
                    )  # [B, T=1, D]
                    latent_pred[person] = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
                    future_motion_pred[person] = self.vae_model.decode(latent_pred[person], history_motion[person], nfuture=future_length,
                                                                scale_latent=self.denoiser_args.rescale_latent)
                    future_frames[person] = self.dataset.denormalize(future_motion_pred[person])
                    future_start, future_end = 0, valid_length
                    valid_future_frames[person] = future_frames[person][:, future_start:future_end, :]  # ignore the initial standing seed
                    all_frames[person] = torch.cat([self.dataset.denormalize(valid_future_frames[person]), future_frames[person]], dim=1)
                    new_history_end = history_length + valid_length
                    new_history_start = new_history_end - history_length
                    new_history_frames[person] = all_frames[person][:, new_history_start:new_history_end, :]
                    
                    if self.denoiser_args.use_pre_latent:
                        """store pre latent and global transform"""
                        pre_latent[person].append(latent_pred[person].permute(1, 0, 2))
                        pre_transf_rotmat_abs[person].append(transf_rotmat[person])
                        pre_transf_transl_abs[person].append(transf_transl[person])
                    
                    """transform primitive to world coordinate, prepare for serialization"""
                    if motion_sequences[person] is None:
                        all_feature_dict = self.primitive_utility.tensor_to_dict(all_frames[person])
                        all_feature_dict.update(
                            {
                                'transf_rotmat': transf_rotmat[person],
                                'transf_transl': transf_transl[person],
                                'gender': gender[person],
                                'betas': betas[person][:, :history_length+future_end, :],
                                'pelvis_delta': pelvis_delta[person],
                            }
                        )
                        all_primitive_dict = self.primitive_utility.feature_dict_to_smpl_dict(all_feature_dict)
                        all_primitive_dict = self.primitive_utility.transform_primitive_to_world(all_primitive_dict)
                        motion_sequences[person] = all_primitive_dict
                    else:
                        future_feature_dict = self.primitive_utility.tensor_to_dict(valid_future_frames[person])
                        future_feature_dict.update(
                            {
                                'transf_rotmat': transf_rotmat[person],
                                'transf_transl': transf_transl[person],
                                'gender': gender[person],
                                'betas': betas[person][:, future_start:future_end, :],
                                'pelvis_delta': pelvis_delta[person],
                            }
                        )
                        future_primitive_dict = self.primitive_utility.feature_dict_to_smpl_dict(future_feature_dict)
                        future_primitive_dict = self.primitive_utility.transform_primitive_to_world(future_primitive_dict)
                        for key in motion_sequences[person].keys():
                            if key in ['transl', 'global_orient', 'body_pose', 
                                       'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose', 
                                       'betas', 'joints']:
                                motion_sequences[person][key] = torch.cat([motion_sequences[person][key], future_primitive_dict[key]], dim=1)  # [B, T, ...]

                    """update history motion seed, update global transform"""
                    if primitive_id < num_primitives - 1:
                        history_feature_dict = self.primitive_utility.tensor_to_dict(new_history_frames[person])
                        history_feature_dict.update(
                            {
                                'transf_rotmat': transf_rotmat[person],
                                'transf_transl': transf_transl[person],
                                'gender': gender[person],
                                'betas': betas[person][:, new_history_start:new_history_end, :],
                                'pelvis_delta': pelvis_delta[person],
                            }
                        )
                        if self.args.fix_floor and primitive_id == num_primitives - 1:  # fix the first frame feet of each segment to be on floor
                            foot_height = history_feature_dict['joints'].reshape(-1, history_length, 22, 3)[:, 0, FOOT_JOINTS_IDX, 2].amin(dim=-1)  # [B]
                            foot_height_world = foot_height + history_feature_dict['transf_transl'][:, 0, 2]  # [B]
                            history_feature_dict['transf_transl'][:, 0, 2] -= foot_height_world
                        canonicalized_history_primitive_dict, blended_feature_dict = self.primitive_utility.get_blended_feature(
                            history_feature_dict, use_predicted_joints=self.args.use_predicted_joints)
                        transf_rotmat[person], transf_transl[person] = canonicalized_history_primitive_dict['transf_rotmat'], \
                        canonicalized_history_primitive_dict['transf_transl']
                        history_motion[person] = self.primitive_utility.dict_to_tensor(blended_feature_dict)
                        history_motion[person] = self.dataset.normalize(history_motion[person])  # [B, T, D]
                if primitive_id < num_primitives - 1:
                    rel_rotmat, rel_transl, rel_pose = {}, {}, {}
                    rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(
                        transf_rotmat['person1'], transf_transl['person1'],transf_rotmat['person2'], transf_transl['person2'])
                    rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(
                        transf_rotmat['person2'], transf_transl['person2'],transf_rotmat['person1'], transf_transl['person1'])
                    for rel in ['b2a', 'a2b']:
                        rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                        rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]
                    interaction['rel_pose_b2a'] = rel_pose['b2a']
                    interaction['rel_pose_a2b'] = rel_pose['a2b']
            else:
                y = {
                    'history_motion_normalized': history_motion,
                    'scale': guidance_param,
                }
                if self.denoiser_args.load_text_embedding:
                    y['text_embedding'] = batch_data['text_embedding'].to(device)
                else:
                    if self.text_encoder_version == 'v1':
                        y['text_embedding'] = self.text_encoder(batch_data['texts']).to(device)
                    elif self.text_encoder_version in ['v2', 'v3']:
                        y['text_embedding'], batch_data['text_mask'] = self.text_encoder(batch_data['texts'])
                        y['text_embedding'] = y['text_embedding'].to(device)
                if self.denoiser_args.text_sep:
                    y['text_mask'] = batch_data['text_mask'].to(device)

                x_start_pred = sample_fn(
                    self.denoiser_model,
                    (batch_size, *self.denoiser_args.model_args.noise_shape),
                    clip_denoised=False,
                    model_kwargs={'y': y},
                    skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                    init_image=None,
                    progress=False,
                    dump_steps=None,
                    noise=torch.zeros_like(guidance_param) if self.args.zero_noise else None,
                    const_noise=False,
                )  # [B, T=1, D]
                latent_pred = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
                future_motion_pred = self.vae_model.decode(latent_pred, history_motion, nfuture=future_length,
                                                    scale_latent=self.denoiser_args.rescale_latent)  # [B, F, D], normalized
                future_frames = self.dataset.denormalize(future_motion_pred)
                all_frames = torch.cat([self.dataset.denormalize(history_motion), future_frames], dim=1)
                
                future_start, future_end = 0, valid_length
                valid_future_frames = future_frames[:, future_start:future_end, :]  # ignore the initial standing seed
                new_history_end = history_length + valid_length
                new_history_start = new_history_end - history_length
                new_history_frames = all_frames[:, new_history_start:new_history_end, :]
                
                """transform primitive to world coordinate, prepare for serialization"""
                future_feature_dict = self.primitive_utility.tensor_to_dict(valid_future_frames)
                future_feature_dict.update(
                    {
                        'transf_rotmat': transf_rotmat,
                        'transf_transl': transf_transl,
                        'gender': gender,
                        'betas': betas[:, future_start:future_end, :],
                        'pelvis_delta': pelvis_delta,
                    }
                )
                future_primitive_dict = self.primitive_utility.feature_dict_to_smpl_dict(future_feature_dict)
                future_primitive_dict = self.primitive_utility.transform_primitive_to_world(future_primitive_dict)
                if motion_sequences is None:
                    motion_sequences = future_primitive_dict
                else:
                    for key in motion_sequences.keys():
                        if key in ['transl', 'global_orient', 'body_pose', 'betas', 'joints']:
                            motion_sequences[key] = torch.cat([motion_sequences[key], future_primitive_dict[key]], dim=1)  # [B, T, ...]

                """update history motion seed, update global transform"""
                history_feature_dict = self.primitive_utility.tensor_to_dict(new_history_frames)
                history_feature_dict.update(
                    {
                        'transf_rotmat': transf_rotmat,
                        'transf_transl': transf_transl,
                        'gender': gender,
                        'betas': betas[:, new_history_start:new_history_end, :],
                        'pelvis_delta': pelvis_delta,
                    }
                )
                if self.args.fix_floor and primitive_id == num_primitives - 1:  # fix the first frame feet of each segment to be on floor
                    foot_height = history_feature_dict['joints'].reshape(-1, history_length, 22, 3)[:, 0, FOOT_JOINTS_IDX, 2].amin(dim=-1)  # [B]
                    foot_height_world = foot_height + history_feature_dict['transf_transl'][:, 0, 2]  # [B]
                    history_feature_dict['transf_transl'][:, 0, 2] -= foot_height_world
                canonicalized_history_primitive_dict, blended_feature_dict = self.primitive_utility.get_blended_feature(
                    history_feature_dict, use_predicted_joints=self.args.use_predicted_joints)
                transf_rotmat, transf_transl = canonicalized_history_primitive_dict['transf_rotmat'], \
                canonicalized_history_primitive_dict['transf_transl']
                history_motion = self.primitive_utility.dict_to_tensor(blended_feature_dict)
                history_motion = self.dataset.normalize(history_motion)  # [B, T, D]
        def add_text(motion_sequences, texts):
            motion_sequences['texts'] = texts
            
        if self.data_args.interaction:
            for person in ['person1', 'person2']:
                add_text(motion_sequences[person], batch_data['interaction']['texts'])
        else:
            add_text(motion_sequences, batch_data['texts'])
        
        # padding
        def pad_motion(motion_sequences, padding_length):
            for key in motion_sequences.keys():
                if key in ['betas', 'transl', 'joints', 'global_orient', 'body_pose', 
                           'jaw_pose', 'left_eye_pose', 'right_eye_pose', 'left_hand_pose', 'right_hand_pose']:
                    last_frame = motion_sequences[key][:, -1:]
                    padding = last_frame.repeat(1, padding_length, *([1] * (motion_sequences[key].dim() - 2)))
                    motion_sequences[key] = torch.cat([motion_sequences[key], padding], dim=1)
                else:
                    continue
            return motion_sequences
        
        if lengths < self.max_motion_length + 1:
            padding_length = self.max_motion_length + 1 - lengths
            if self.data_args.interaction:
                for person in ['person1', 'person2']:
                    motion_sequences[person] = pad_motion(motion_sequences[person], padding_length)
            else:
                motion_sequences = pad_motion(motion_sequences, padding_length)
        
        # for person in ['person1', 'person2']:
        #     for key in motion_sequences[person].keys():
        #         if key in ['betas', 'transl', 'global_orient', 'body_pose', 'pelvis_delta', 'joints', 'transf_rotmat', 'transf_transl']:
        #             motion_sequences[person][key] = motion_sequences[person][key].squeeze(0)   
        
        if self.data_args.interaction:
            for person in ['person1', 'person2']:
                tensor_dict_to_device(motion_sequences[person], 'cpu')
        else:
            tensor_dict_to_device(motion_sequences, 'cpu')
        tensor_dict_to_device(motion_sequences, 'cpu')
        return motion_sequences
    
    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
        sent_len = data['cap_len']
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            try:
                word_emb, pos_oh = self.w_vectorizer[token]
            except:
                word_emb, pos_oh = self.w_vectorizer['unk/OTHER']
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)
        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)


class MMGeneratedDataset(Dataset):
    """Multimodality view over ``motion_dataset.mm_generated_motions``.

    Each item is ``(mm_motions, m_lens)`` taken from one entry of that list.
    ``m_lens`` is ``entry['length']`` converted to an integer numpy array.
    """

    def __init__(self, motion_dataset, w_vectorizer):
        self.max_motion_length = motion_dataset.max_motion_length
        self.dataset = motion_dataset.mm_generated_motions
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data = self.dataset[item]
        mm_motions = data['mm_motions']
        # ``np.int`` was merely an alias of the builtin ``int`` and was removed
        # in NumPy 1.24; ``dtype=int`` yields the same (default integer) dtype.
        m_lens = np.array(data['length'], dtype=int)
        return mm_motions, m_lens


def get_motion_loader(args, batch_size, model, denoiser_args, diffusion, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset, device, mm_num_samples, mm_num_repeats,**kwargs):
    """Build loaders over generated motions, generating them first if needed.

    Generated samples are cached as ``*.pt`` files in
    ``args.load_dir/replications_<replication_times>`` when
    ``replication_times > 1``, otherwise in
    ``args.load_dir/replication_<replication>``. If that directory is missing,
    or its ``.pt`` count differs from ``len(ground_truth_dataset)``, generation
    is (re)launched via ``launch_generation_on_multi_gpus``.

    Keyword args:
        replication (int): replication id used in the cache dir name (default 0).
        replication_times (int): repeats generated per non-MM sample (default 1).
        text_encoder: forwarded to the generation workers (default None).

    Returns:
        (motion_loader, mm_motion_loader)
    """
    replication = kwargs.get('replication', 0)
    replication_times = kwargs.get('replication_times', 1)
    text_encoder = kwargs.get('text_encoder', None)
    w_vectorizer = WordVectorizer(pjoin(data_args.data_dir.rsplit('/', 1)[0] + '/' + 'inter-x', 'glove'), 'hhi_vab')

    if replication_times > 1:
        load_dir = os.path.join(args.load_dir, f"replications_{replication_times}")
    else:
        load_dir = os.path.join(args.load_dir, f"replication_{replication}")

    def _cache_complete(path):
        # One ``.pt`` file per ground-truth sample means generation already finished.
        if not os.path.exists(path):
            return False
        num_saved = sum(1 for f in os.listdir(path) if f.endswith('.pt'))
        return num_saved == len(ground_truth_dataset)

    if not _cache_complete(load_dir):
        # ``diffusion`` is not forwarded: each worker rebuilds it from diffusion_args.
        launch_generation_on_multi_gpus(
            args, model, denoiser_args, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset,
            mm_num_samples, mm_num_repeats, load_dir,
            text_encoder=text_encoder,
            # Must be forwarded. Otherwise the workers default to 1 repeat per
            # sample even when filling a ``replications_<n>`` cache.
            replication_times=replication_times,
        )

    dataset = EvaluationDataset(args, model, denoiser_args, diffusion, diffusion_args, vae_model, vae_args, data_args,
                                ground_truth_dataset, device, mm_num_samples, mm_num_repeats,
                                load_dir=load_dir,
                                w_vectorizer=w_vectorizer,
                                replication_times=replication_times, replication=replication,
                                )
    mm_dataset = MMGeneratedDataset(dataset, w_vectorizer)

    motion_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, num_workers=4, collate_fn=collate_fn)
    mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=0)

    return motion_loader, mm_motion_loader


def get_dataset_motion_loader(data_args, batch_size, device, split, mode, text_sep, cut_length=0, **kwargs):
    """Load the ground-truth evaluation dataset and a shuffled loader over it.

    Only ``data_args.dataset == 'interx'`` is supported; anything else raises
    ``KeyError``. ``mode`` and ``cut_length`` are accepted but not used here.

    Keyword args:
        clip_version (str): default ``'ViT-B/32'``.
        load_text_embedding (bool): default ``True``.
        use_indi_text (bool): default ``True``.

    Returns:
        (dataloader, dataset)
    """
    if data_args.dataset != 'interx':
        raise KeyError('Dataset not Recognized !!')

    print('Loading dataset %s ...' % data_args.dataset)
    from data_loaders.HHI.data.dataset_interx import InterXDatasetEval

    glove_dir = pjoin(data_args.data_dir.rsplit('/', 1)[0] + '/' + 'inter-x', 'glove')
    vectorizer = WordVectorizer(glove_dir, 'hhi_vab')
    dataset = InterXDatasetEval(
        dataset_name=data_args.dataset,
        dataset_path=data_args.data_dir,
        cfg_path=data_args.cfg_path,
        split=split,
        device=device,
        enforce_gender=data_args.enforce_gender,
        enforce_zero_beta=data_args.enforce_zero_beta,
        body_type=data_args.body_type,
        text_sep=text_sep,
        max_length=data_args.max_length,
        max_text_len=data_args.max_text_len,
        unit_length=data_args.unit_length,
        padding=data_args.padding,
        w_vectorizer=vectorizer,
        clip_version=kwargs.get('clip_version', 'ViT-B/32'),
        load_text_embedding=kwargs.get('load_text_embedding', True),
        use_indi_text=kwargs.get('use_indi_text', True),
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=4,
        drop_last=True,
        shuffle=True,
        collate_fn=collate_fn,
    )

    print('Ground Truth Dataset Loading Completed!!!')
    return dataloader, dataset


def generate_worker(rank, args, all_sample_indices, 
                    model, denoiser_args, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset, 
                    mm_num_samples, mm_num_repeats, mm_sample_indices, save_dir, replication_times=1, text_encoder=None):
    """Generate and save motions for the samples assigned to GPU ``rank``.

    Used as the ``mp.spawn`` target (``rank`` is the process index). Walks the
    ground-truth dataset, keeps only sequences listed in
    ``all_sample_indices[rank]``, and writes one ``motion_<seq_name>.pt`` per
    sequence into ``save_dir``.

    Args:
        all_sample_indices: per-rank lists of ``seq_name`` values.
        mm_sample_indices: ``seq_name`` values generated with ``mm_num_repeats``.
            These are regenerated even if their file exists.
        replication_times: repeat count passed to ``generate_motion`` for
            non-MM samples.
        text_encoder: forwarded to ``EvaluationDataset``.
    """
    torch.set_num_threads(4)
    torch.set_num_interop_threads(4)
    torch.cuda.set_device(rank)

    # A set gives O(1) membership tests in the per-sample loop below; a list
    # made the scan quadratic in the dataset size.
    owned_seq_names = set(all_sample_indices[rank])
    from mld.train_mld_new_v9 import create_gaussian_diffusion
    diffusion = create_gaussian_diffusion(diffusion_args)
    dataset = EvaluationDataset(
        args=args,
        denoiser_model=model.to(rank),
        denoiser_args=denoiser_args,
        diffusion=diffusion,
        diffusion_args=diffusion_args,
        vae_model=vae_model.to(rank),
        vae_args=vae_args,
        data_args=data_args,
        ground_truth_dataset=ground_truth_dataset,
        device=torch.device(f"cuda:{rank}"),
        mm_num_samples=mm_num_samples,
        mm_num_repeats=mm_num_repeats,
        load_data=False,
        text_encoder=text_encoder,
    )

    dataloader = DataLoader(
        ground_truth_dataset, batch_size=1, shuffle=False, num_workers=0
    )

    with torch.no_grad():
        for data in dataloader:
            _, _, caption, _, motions, motion_lens, _ = data
            seq_name = motions['seq_name'][0]
            # Cheap in-memory ownership check first, filesystem check second.
            if seq_name not in owned_seq_names:
                continue
            is_mm = seq_name in mm_sample_indices
            save_path = os.path.join(save_dir, f"motion_{seq_name}.pt")
            # Finished non-MM samples are reused. MM samples are always
            # regenerated; presumably because the MM subset is redrawn on every
            # launch and needs mm_num_repeats outputs.
            if not is_mm and os.path.exists(save_path):
                continue

            motion_sequences = dataset.generate_motion(
                motions, caption, motion_lens, mm_num_repeats if is_mm else replication_times)

            torch.save(motion_sequences, save_path)

def launch_generation_on_multi_gpus(args, model, denoiser_args, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset, 
                    mm_num_samples, mm_num_repeats, save_dir, **kwargs):
    """Shard generation over all visible CUDA devices with ``mp.spawn``.

    Sequence names are read from each ground-truth item (``item[-3]['seq_name']``)
    and shuffled. The first ``mm_num_samples`` become the multimodality subset,
    written to ``save_dir/mm_indices.json``. The shuffled list is dealt
    round-robin to ``generate_worker`` processes, one per GPU.

    Keyword args:
        replication_times (int): forwarded to the workers (default 1).
        text_encoder: forwarded to the workers (default None).

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        # Without this guard the round-robin split divides by zero, or
        # mp.spawn(nprocs=0) silently does nothing.
        raise RuntimeError(
            "launch_generation_on_multi_gpus requires at least one CUDA device, but none were found."
        )

    real_indices = [ground_truth_dataset[i][-3]['seq_name'] for i in range(len(ground_truth_dataset))]
    random.shuffle(real_indices)
    mm_sample_indices = set(real_indices[:mm_num_samples])

    replication_times = kwargs.get('replication_times', 1)
    text_encoder = kwargs.get('text_encoder', None)

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "mm_indices.json"), 'w') as f:
        json.dump(sorted(mm_sample_indices), f)

    # Round-robin: GPU k receives shuffled items k, k + num_gpus, k + 2*num_gpus, ...
    per_gpu_indices = [real_indices[rank::num_gpus] for rank in range(num_gpus)]

    mp.spawn(
        generate_worker,
        args=(args, per_gpu_indices, model, denoiser_args, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset,
              mm_num_samples, mm_num_repeats, mm_sample_indices, save_dir, replication_times, text_encoder),
        nprocs=num_gpus,
        join=True
    )


def init_weight(m):
    """Xavier-normal init for Conv1d / Linear / ConvTranspose1d weights.

    The bias is zeroed when present. Modules of any other type are left
    untouched, so this can be passed directly to ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Conv1d, nn.Linear, nn.ConvTranspose1d)):
        return
    nn.init.xavier_normal_(m.weight)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

class MovementConvEncoder(nn.Module):
    """Two strided 1-D convolutions followed by a per-step linear head.

    Each Conv1d uses kernel 4, stride 2, padding 1, so the time axis shrinks
    to roughly ``T // 4``. ``forward`` takes ``(batch, time, input_size)``,
    since the channel axis fed to Conv1d is the input's last dimension.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names (``main``, ``out_net``) and the Sequential ordering
        # must stay as-is: pretrained checkpoints are loaded by state_dict key.
        conv_stack = [
            nn.Conv1d(input_size, hidden_size, 4, 2, 1),
            nn.Dropout(0.2, inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(hidden_size, output_size, 4, 2, 1),
            nn.Dropout(0.2, inplace=True),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.main = nn.Sequential(*conv_stack)
        self.out_net = nn.Linear(output_size, output_size)
        for submodule in (self.main, self.out_net):
            submodule.apply(init_weight)

    def forward(self, inputs):
        # (batch, time, channels) -> (batch, channels, time) for Conv1d, then back.
        features = self.main(inputs.transpose(1, 2))
        return self.out_net(features.transpose(1, 2))

class TextEncoderBiGRUCo(nn.Module):
    """Bidirectional-GRU caption encoder producing one embedding per caption.

    Word vectors are summed with a linear projection of the POS one-hots,
    projected to ``hidden_size``, and run through a bi-GRU with a learned
    initial state. The final forward and backward hidden states are then
    concatenated and mapped to ``output_size``.
    """

    def __init__(self, word_size, pos_size, hidden_size, output_size, device):
        super(TextEncoderBiGRUCo, self).__init__()
        self.device = device

        self.pos_emb = nn.Linear(pos_size, word_size)
        self.input_emb = nn.Linear(word_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_net = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size)
        )

        self.input_emb.apply(init_weight)
        self.pos_emb.apply(init_weight)
        self.output_net.apply(init_weight)
        self.hidden_size = hidden_size
        # Learned initial hidden state (2 directions, 1, hidden), repeated over the batch.
        self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))

    def forward(self, word_embs, pos_onehot, cap_lens):
        """Encode a padded batch of captions.

        Args:
            word_embs: ``(batch, max_len, word_size)`` word vectors.
            pos_onehot: ``(batch, max_len, pos_size)`` POS one-hots.
            cap_lens: valid length per caption, as a tensor or a sequence of
                ints. It no longer needs to be sorted.

        Returns:
            ``(batch, output_size)`` embeddings in the input order.
        """
        num_samples = word_embs.shape[0]

        pos_embs = self.pos_emb(pos_onehot)
        inputs = word_embs + pos_embs
        input_embs = self.input_emb(inputs)
        hidden = self.hidden.repeat(1, num_samples, 1)

        lengths = cap_lens.tolist() if torch.is_tensor(cap_lens) else list(cap_lens)
        # enforce_sorted=False: PyTorch sorts internally, permutes ``hidden``
        # (all rows are identical anyway), and returns h_n in the caller's order.
        # Already-sorted batches produce the same result as before.
        emb = pack_padded_sequence(input_embs, lengths, batch_first=True, enforce_sorted=False)

        _, gru_last = self.gru(emb, hidden)

        # h_n[0] / h_n[1] are the forward / backward final states.
        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)

        return self.output_net(gru_last)

class MotionEncoderBiGRUCo(nn.Module):
    """Bidirectional-GRU encoder mapping movement features to one embedding per sequence.

    Inputs are projected to ``hidden_size`` and run through a bi-GRU with a
    learned initial state. The final forward and backward hidden states are
    then concatenated and mapped to ``output_size``.
    """

    def __init__(self, input_size, hidden_size, output_size, device):
        super(MotionEncoderBiGRUCo, self).__init__()
        self.device = device

        self.input_emb = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.output_net = nn.Sequential(
            nn.Linear(hidden_size*2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(hidden_size, output_size)
        )

        self.input_emb.apply(init_weight)
        self.output_net.apply(init_weight)
        self.hidden_size = hidden_size
        # Learned initial hidden state (2 directions, 1, hidden), repeated over the batch.
        self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))

    def forward(self, inputs, m_lens):
        """Encode a padded batch of movement sequences.

        Args:
            inputs: ``(batch_size, seq_len, input_size)`` features.
            m_lens: valid length per sequence, as a tensor or a sequence of
                ints. It no longer needs to be sorted.

        Returns:
            ``(batch_size, output_size)`` embeddings in the input order.
        """
        num_samples = inputs.shape[0]

        input_embs = self.input_emb(inputs)
        hidden = self.hidden.repeat(1, num_samples, 1)

        lengths = m_lens.tolist() if torch.is_tensor(m_lens) else list(m_lens)
        # enforce_sorted=False: PyTorch sorts internally and returns h_n in the
        # caller's order, so already-sorted batches produce the same result.
        emb = pack_padded_sequence(input_embs, lengths, batch_first=True, enforce_sorted=False)

        _, gru_last = self.gru(emb, hidden)

        # h_n[0] / h_n[1] are the forward / backward final states.
        gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)

        return self.output_net(gru_last)
    
def build_models(opt):
    """Construct the three evaluator encoders and load their pretrained weights.

    Weights come from ``<opt.checkpoints_dir>/text_mot_match/model/finest.tar``,
    under the keys ``'movement_encoder'``, ``'text_encoder'`` and
    ``'motion_encoder'``.

    Returns:
        (text_enc, motion_enc, movement_enc)
    """
    # Construction order is kept (movement, text, motion) so RNG consumption is unchanged.
    movement_enc = MovementConvEncoder(opt.dim_pose, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    text_enc = TextEncoderBiGRUCo(
        word_size=opt.dim_word,
        pos_size=opt.dim_pos_ohot,
        hidden_size=opt.dim_text_hidden,
        output_size=opt.dim_coemb_hidden,
        device=opt.device,
    )
    motion_enc = MotionEncoderBiGRUCo(
        input_size=opt.dim_movement_latent,
        hidden_size=opt.dim_motion_hidden,
        output_size=opt.dim_coemb_hidden,
        device=opt.device,
    )

    ckpt_path = pjoin(opt.checkpoints_dir, 'text_mot_match', 'model', 'finest.tar')
    checkpoint = torch.load(ckpt_path, map_location=opt.device)
    for module, state_key in ((movement_enc, 'movement_encoder'),
                              (text_enc, 'text_encoder'),
                              (motion_enc, 'motion_encoder')):
        module.load_state_dict(checkpoint[state_key])
    print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
    return text_enc, motion_enc, movement_enc

class EvaluatorModelWrapper(object):
    """Pretrained text-motion matching evaluator for two-person motions.

    Wraps the movement, motion and text encoders built by ``build_models``.
    It turns batches whose motion entry is a ``{'person1': ..., 'person2': ...}``
    dict of SMPL-X parameters into embeddings. Both embedding methods reorder
    the batch by descending motion length, so their outputs do NOT follow the
    input order.
    """

    def __init__(self, opt):
        # Evaluator hyper-parameters are written onto ``opt`` because
        # ``build_models`` reads them from there.
        opt.dim_pose = 56 * 12  # 56 feature slots x (6 values x 2 people); see cal_feature

        opt.dim_word = 300
        opt.max_motion_length = 196
        opt.dim_pos_ohot = len(POS_enumerator)
        opt.dim_motion_hidden = 1024
        opt.max_text_len = 20
        opt.dim_text_hidden = 512
        opt.dim_coemb_hidden = 512
        if opt.dataset_name == 'hhi':
            opt.max_motion_length = 150
            opt.max_text_len = 35

        self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
        self.opt = opt
        self.device = opt.device

        self.text_encoder.to(opt.device)
        self.motion_encoder.to(opt.device)
        self.movement_encoder.to(opt.device)

        self.text_encoder.eval()
        self.motion_encoder.eval()
        self.movement_encoder.eval()

    def prepare_motion(self, motions):
        """Zero both people's root velocity, then flatten the feature slots.

        ``motions`` is ``[B, T, 56, 12]`` as produced by ``cal_feature``. In
        the last slot, channels 3:6 and 9:12 hold person1's and person2's
        translation velocity. Those channels are zeroed IN PLACE, and the
        result is returned reshaped to ``[B, T, 56 * 12]``.
        """
        # NOTE(review): why velocities are dropped is not visible from this file.
        motions[:, :, -1, 9:] = 0
        motions[:, :, -1, 3:6] = 0
        motions = motions.reshape(motions.shape[0], motions.shape[1], -1)
        return motions

    def cal_feature(self, motions):
        """Convert both people's SMPL-X parameter dicts into evaluator features.

        For each person, the rotation matrices (global_orient, body_pose, jaw,
        eyes, hands) are converted to 6-D, keeping all frames except the last.
        A final slot ``[transl_t, transl_{t+1} - transl_t]`` is appended. The
        two people are concatenated on the last axis, giving
        ``[..., T-1, 56, 12]``.

        Side effect: every tensor in ``motions['person1']`` and
        ``motions['person2']`` is replaced by its copy on ``self.device``, so
        the caller's dict is updated.
        """
        def tensor_to_device(tensor_dict, device):
            return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in tensor_dict.items()}

        for person in ['person1', 'person2']:
            motions[person] = tensor_to_device(motions[person], self.device)

        features = {}
        for person in ['person1', 'person2']:
            body_pose = torch.cat([motions[person]['global_orient'].unsqueeze(-3),
                                   motions[person]['body_pose'],
                                   motions[person]['jaw_pose'].unsqueeze(-3),
                                   motions[person]['left_eye_pose'].unsqueeze(-3),
                                   motions[person]['right_eye_pose'].unsqueeze(-3),
                                   motions[person]['left_hand_pose'],
                                   motions[person]['right_hand_pose']], dim=-3)
            # Drop the last frame so rotations align with the T-1 velocity frames.
            rot_6d = transforms.matrix_to_rotation_6d(body_pose)[..., :-1, :, :]
            vel = motions[person]['transl'][..., 1:, :] - motions[person]['transl'][..., :-1, :]
            transl_vel = torch.cat([motions[person]['transl'][..., :-1, :], vel], dim=-1)
            features[person] = torch.cat([rot_6d, transl_vel.unsqueeze(-2)], dim=-2)  # [B, T, 56, 6]
        motions = torch.cat([features['person1'], features['person2']], dim=-1)  # [B, T, 56, 2*6]

        return motions

    def _encode_motions(self, motions, m_lens):
        """Sort by descending length, then run the movement and motion encoders.

        Must be called under ``torch.no_grad()``. Returns
        ``(motion_embedding, align_idx)``. ``align_idx`` is the permutation
        applied to the batch, so callers can reorder other per-sample outputs
        to match.
        """
        motions = motions.detach().to(self.device).float()

        align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
        motions = motions[align_idx]
        m_lens = m_lens[align_idx]

        motions = self.prepare_motion(motions)
        movements = self.movement_encoder(motions).detach()
        # The movement encoder downsamples time, so lengths are rescaled by unit_length.
        m_lens = m_lens // self.opt.unit_length
        motion_embedding = self.motion_encoder(movements, m_lens)
        return motion_embedding, align_idx

    def get_co_embeddings(self, batch):
        """Return ``(text_embedding, motion_embedding)`` for an evaluation batch.

        ``batch`` is the 7-tuple
        ``(word_embs, pos_ohot, caption, cap_lens, motions, m_lens, tokens)``.
        Both outputs are ordered by descending ``m_lens``, not by input order.
        """
        word_embs, pos_ohot, _, cap_lens, motions, m_lens, _ = batch
        motions = self.cal_feature(motions)
        with torch.no_grad():
            word_embs = word_embs.detach().to(self.device).float()
            pos_ohot = pos_ohot.detach().to(self.device).float()

            motion_embedding, align_idx = self._encode_motions(motions, m_lens)

            # Text is encoded in input order, then permuted to match the motions.
            text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
            text_embedding = text_embedding[align_idx]
        return text_embedding, motion_embedding

    def get_motion_embeddings(self, batch):
        """Return motion embeddings, ordered by descending length (not input order).

        Accepts either the 7-tuple evaluation batch or a 2-tuple
        ``(motions, m_lens)`` as produced by the MM loader.
        """
        try:
            _, _, _, _, motions, m_lens, _ = batch
        except ValueError:
            motions, m_lens = batch
        motions = self.cal_feature(motions)
        if len(motions.shape) == 5:
            # MM loader (batch_size=1) carries an extra leading axis:
            # [1, R, T, 56, 12] -> [R, T, 56, 12]. R is presumably the repeat count.
            motions = motions[0]
            m_lens = m_lens[0]
        with torch.no_grad():
            motion_embedding, _ = self._encode_motions(motions, m_lens)
        return motion_embedding

