import torch
import random
import copy
import os
import json
import numpy as np
from tqdm import tqdm
from pytorch3d import transforms
from torch.utils.data import Dataset, DataLoader
from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate

from utils.smpl_utils import *

def get_collate_fn(name, hml_mode='train'):
    """Pick the collate function matching a dataset name and loader mode.

    Args:
        name: dataset identifier; ``"humanml"`` and ``"kit"`` get the
            text-to-motion collate, every other name the generic one.
        hml_mode: when ``'gt'``, the HumanML loader's own ``collate_fn`` is
            returned regardless of ``name``.

    Returns:
        The collate callable to pass to a ``DataLoader``.
    """
    if hml_mode == 'gt':
        # Imported lazily so the HumanML dataset module is only loaded when needed.
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
        return t2m_eval_collate
    return t2m_collate if name in ("humanml", "kit") else all_collate

def mp_collate(batch):
    """Collate motion-primitive samples into a ``(motion, cond)`` pair.

    Samples are reordered by gender: every ``'female'`` sample first, then
    every ``'male'`` sample, each group keeping its original relative order.
    Samples with any other gender value are kept rather than dropped, and
    are placed after the male group.

    Args:
        batch: non-empty list of sample dicts with keys ``'text'``,
            ``'gender'``, ``'betas'``, ``'motion_tensor_normalized'``,
            ``'history_mask'``, ``'history_motion'``, ``'history_length'`` and
            ``'future_length'``. Tensor entries must share a shape across
            samples so they can be stacked.

    Returns:
        ``(motion, cond)``. ``motion`` stacks ``motion_tensor_normalized``
        along a new leading batch axis. ``cond = {'y': {...}}`` holds the
        per-sample texts and genders, the stacked ``betas`` /
        ``history_motion`` / ``history_mask`` tensors, and the
        ``history_length`` / ``future_length`` of the first (post-sort)
        sample.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("mp_collate received an empty batch")

    gender_rank = {'female': 0, 'male': 1}
    unknown_rank = len(gender_rank)
    # sorted() is stable, so samples keep their relative order inside a gender group;
    # genders outside gender_rank sort last instead of being discarded.
    batch = sorted(batch, key=lambda sample: gender_rank.get(sample['gender'], unknown_rank))

    def stack(key):
        # Stack one tensor field of every sample along a new leading batch axis.
        return torch.stack([sample[key] for sample in batch], dim=0)

    motion = stack('motion_tensor_normalized')  # (B, D, 1, T)
    cond = {'y': {'text': [sample['text'] for sample in batch],
                  'gender': [sample['gender'] for sample in batch],
                  'betas': stack('betas'),  # (B, T, 10)
                  'history_motion': stack('history_motion'),
                  'history_mask': stack('history_mask'),
                  # Taken from the first sample; assumed identical across the batch.
                  'history_length': batch[0]['history_length'],
                  'future_length': batch[0]['future_length'],
                  }
            }
    return motion, cond


def get_dataset_loader_mp(dataset_path, batch_size, split='train'):
    """Build a ``DataLoader`` over ``Text2MotionPrimitiveDataset``.

    Args:
        dataset_path: forwarded to the dataset as ``dataset_path``.
        batch_size: number of samples per batch.
        split: dataset split name; only ``'train'`` is shuffled.

    Returns:
        A ``DataLoader`` using 8 worker processes, dropping the last partial
        batch, and collating with ``mp_collate``.
    """
    # Imported lazily so the dataset module is only loaded when a loader is requested.
    from data_loaders.humanml.data.dataset import Text2MotionPrimitiveDataset

    primitive_dataset = Text2MotionPrimitiveDataset(dataset_path=dataset_path, split=split, load_data=True)
    is_training = split == 'train'
    return DataLoader(
        primitive_dataset,
        batch_size=batch_size,
        shuffle=is_training,
        num_workers=8,
        drop_last=True,
        collate_fn=mp_collate,
    )


def tensor_dict_to_device(tensor_dict, device, dtype=torch.float32):
    """Move every ``torch.Tensor`` value of ``tensor_dict`` to ``device``, in place.

    Non-tensor values are left untouched.

    Args:
        tensor_dict: dict whose tensor values are replaced by their copies on ``device``.
        device: target device passed to ``Tensor.to(device=...)``.
        dtype: accepted for API compatibility but currently NOT applied;
            tensors keep their original dtype.

    Returns:
        The same ``tensor_dict`` object (mutated).
    """
    tensor_keys = [key for key, value in tensor_dict.items() if isinstance(value, torch.Tensor)]
    for key in tensor_keys:
        tensor_dict[key] = tensor_dict[key].to(device=device)
    return tensor_dict

def extract_replication(data, indices):
    """Select samples along the first axis for both ``'person1'`` and ``'person2'``.

    Args:
        data: dict with ``'person1'`` and ``'person2'`` sub-dicts whose values
            are lists or tensors. Other top-level keys are not copied.
        indices: a single integer (Python or NumPy) or an iterable of
            integers (list, tuple, range, 1-D tensor, ...). A single integer
            is wrapped in a list, so the selected axis is kept with length 1.

    Returns:
        A new dict with the same two-person structure. Lists become lists of
        the selected items; tensors become ``value[indices]`` (rows along
        the first axis).

    Raises:
        TypeError: if a value is neither a list nor a tensor.
    """
    if isinstance(indices, (int, np.integer)):
        indices = [int(indices)]
    elif isinstance(indices, torch.Tensor):
        indices = indices.reshape(-1).tolist()
    else:
        # Materialize into a list: torch treats a tuple index as multi-dimensional
        # indexing (value[(0, 2)] -> element [0, 2], not rows 0 and 2), and a
        # generator could only be consumed once across the many keys below.
        indices = list(indices)

    extracted = {}
    for person in ['person1', 'person2']:
        extracted[person] = {}
        for key, value in data[person].items():
            if isinstance(value, list):
                extracted[person][key] = [value[i] for i in indices]
            elif isinstance(value, torch.Tensor):
                extracted[person][key] = value[indices]
            else:
                raise TypeError(f"Unsupported data type for key '{key}': {type(value)}")
    return extracted

class EvaluationDataset(Dataset):
    def __init__(self, 
                 args, 
                 denoiser_model, 
                 denoiser_args, 
                 diffusion, 
                 diffusion_args, 
                 vae_model, 
                 vae_args, 
                 data_args, 
                 ground_truth_dataset, 
                 device, 
                 mm_num_samples, 
                 mm_num_repeats,
                 load_data=True,
                 load_dir='',
                 **kwargs,
                 ):
        """Generate (or load) one motion per ground-truth sequence for evaluation.

        Iterates ``ground_truth_dataset`` one sequence at a time (shuffled).
        For each sequence, the motion is either generated with
        ``generate_motion`` or, when ``args.load_from_file`` is set, loaded
        from ``<load_dir>/motion_<seq_name>.pt``. The first sample of every
        result is stored in ``self.generated_motions``. Sequences selected for
        the multimodality (MM) evaluation additionally store all their
        samples in ``self.mm_generated_motions``.

        Args:
            args: run options. ``load_from_file`` is read here;
                ``generate_motion`` reads further fields.
            denoiser_model: denoiser network; switched to eval mode.
            denoiser_args, diffusion, diffusion_args, vae_model, vae_args,
                data_args: models and configs stored for ``generate_motion``.
                ``data_args.interaction`` selects the two-person layout.
            ground_truth_dataset: dataset whose items unpack as
                ``(seq_name, interaction_text, batch_data, motion_lens)`` after
                batching. It must expose ``max_length`` and
                ``primitive_utility``.
            device: device for the ``PrimitiveUtility`` and generation.
            mm_num_samples: number of MM sequences when generating. When
                loading, the MM set is read from ``<load_dir>/mm_indices.json``.
            mm_num_repeats: number of samples per MM sequence.
            load_data: if False, only configuration is stored; no motions are
                produced and ``generated_motions`` is not set.
            load_dir: directory with pre-generated motions and ``mm_indices.json``.
            **kwargs: optional ``replication_times`` (default 1),
                ``replication`` (default 0) and ``text_encoder`` (default None).
        """
        self.args = args
        self.denoiser_model = denoiser_model
        self.denoiser_model.eval()
        self.denoiser_args = denoiser_args
        self.diffusion = diffusion
        self.diffusion_args = diffusion_args
        self.vae_model = vae_model
        self.vae_args = vae_args
        self.data_args = data_args
        self.dataset = ground_truth_dataset
        # batch_size=1: every ground-truth sequence is processed on its own.
        dataloader = DataLoader(ground_truth_dataset, batch_size=1, num_workers=0, shuffle=True)
        self.device = device
        self.mm_num_samples = mm_num_samples
        self.mm_num_repeats = mm_num_repeats
        self.max_length = self.dataset.max_length
        self.replication_times = kwargs.get('replication_times', 1)
        self.replication = kwargs.get('replication', 0)
        # A fresh PrimitiveUtility on self.device, mirroring the dataset's body type and motion representation.
        self.primitive_utility = PrimitiveUtility(device=self.device,
                                                  body_type=self.dataset.primitive_utility.body_type,
                                                  motion_repr=self.dataset.primitive_utility.motion_repr)
        print('body_type:', self.primitive_utility.body_type)

        self.text_encoder = kwargs.get('text_encoder', None)
        self.text_encoder_version = getattr(denoiser_args, 'text_encoder_version', 'v1')

        if not load_data:
            return

        def first_sample(sequences):
            # Keep only sample 0 of each tensor/list entry; other values are copied as-is.
            first = {}
            for key, value in sequences.items():
                if isinstance(value, torch.Tensor):
                    first[key] = value[0]
                elif isinstance(value, list):
                    first[key] = [value[0]]
                else:
                    first[key] = value
            return first

        if args.load_from_file:
            with open(os.path.join(load_dir, "mm_indices.json")) as f:
                mm_idxs = set(json.load(f))
        else:
            idxs = list(range(len(self.dataset)))
            random.shuffle(idxs)
            # Compared against the loop counter `i` below.
            mm_idxs = idxs[:mm_num_samples]

        generated_motions = []
        mm_generated_motions = []
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
                seq_name, interaction_text, batch_data, motion_lens = data
                seq_name = seq_name[0]

                if args.load_from_file:
                    path = os.path.join(load_dir, f"motion_{seq_name}.pt")
                    motion_sequences = torch.load(path)
                    if self.replication_times > 1:
                        # NOTE(review): replication selection keys MM membership on the loop
                        # position `i`, while the MM bookkeeping below keys it on `seq_name`
                        # in file mode; confirm which identifier mm_indices.json stores.
                        if i in mm_idxs:
                            motion_sequences = extract_replication(motion_sequences, range(self.replication*self.mm_num_repeats, (self.replication+1) * self.mm_num_repeats))
                        else:
                            motion_sequences = extract_replication(motion_sequences, self.replication)
                    is_mm = seq_name in mm_idxs
                else:
                    # mm_idxs holds loop positions here, so MM membership must be tested
                    # with `i` both for generation and for bookkeeping (testing `seq_name`
                    # against integer positions never matched).
                    is_mm = i in mm_idxs
                    num_samples = mm_num_repeats if is_mm else 1
                    motion_sequences = self.generate_motion(batch_data, interaction_text, motion_lens, num_samples)

                if self.data_args.interaction:
                    first_batch_motion = {person: first_sample(motion_sequences[person])
                                          for person in ['person1', 'person2']}
                    text = motion_sequences['person1']['texts'][0]
                else:
                    first_batch_motion = first_sample(motion_sequences)
                    text = motion_sequences['texts'][0]

                generated_motions.append({'motion': first_batch_motion,
                                          'motion_lens': motion_lens[0],
                                          'text': text})
                if is_mm:
                    mm_generated_motions.append({'mm_motions': motion_sequences,
                                                 'motion_lens': motion_lens[0],
                                                 'text': text})

        self.generated_motions = generated_motions
        self.mm_generated_motions = mm_generated_motions

    def preprocess(self, seq_data, interaction_text):
        # Canonicalization
        canonicalized_primitive_dict, transf_rotmat, transf_transl = {}, {}, {}
        for person in ['person1', 'person2']:
            transf_rotmat[person], transf_transl[person], canonicalized_primitive_dict[person] = self.primitive_utility.canonicalize(copy.deepcopy(seq_data[person]), use_predicted_joints=True)
                    
        if self.dataset.mode == 'merged':
            # rel_pose
            rel_rotmat, rel_transl, rel_pose = {}, {}, {}
            rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person1'], transf_transl['person1'], transf_rotmat['person2'], transf_transl['person2'])
            rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(transf_rotmat['person2'], transf_transl['person2'], transf_rotmat['person1'], transf_transl['person1'])
            for rel in ['b2a', 'a2b']:
                rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [1, 6+3]
        
        data_batch = {} if self.dataset.mode == 'merged' else []
        # calculate features
        feature_dict = {}
        for person in ['person1', 'person2']:
            feature_dict[person] = self.primitive_utility.calc_features(canonicalized_primitive_dict[person], use_predicted_joints=True)
            if self.primitive_utility.feature_dim == 276:
                feature_dict[person]['transl'] = feature_dict[person]['transl'][:, :-1, :]  # [1, T, 3]
                feature_dict[person]['poses_6d'] = feature_dict[person]['poses_6d'][:, :-1, :]  # [1, T, 66]
                feature_dict[person]['joints'] = feature_dict[person]['joints'][:, :-1, :]  # [1, T, 22 * 3]
        
        if self.dataset.mode == 'merged':
            for person in ['person1', 'person2']:
                motion_tensor_normalized = self.dataset.normalize(self.primitive_utility.dict_to_tensor(feature_dict[person]))      # [1, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)                           # [1, D, 1, T]
                
                if self.dataset.use_indi_text:    
                    texts = [random.choice([itext['proc_label'] for itext in seq_data[f'frame_labels_{person}']])]
                    query_person = (
                        'person2' if seq_data['exchange'] and person == 'person1' else
                        'person1' if seq_data['exchange'] and person == 'person2' else
                        person
                    )
                    text_embedding = torch.stack([self.dataset.text_embedding_dict[query_person][text] for text in texts], dim=0)  # [1, 512]
                    text_mask = torch.stack([self.dataset.text_mask_dict[query_person][text] for text in texts], dim=0) if self.dataset.text_sep else None
                            
                data_batch[person] = {
                        'gender': seq_data[person]['gender'],
                        'betas': seq_data[person]['betas'],
                        'motion_tensor_normalized': motion_tensor_normalized,        # [1, D, 1, T]
                        'transf_rotmat': transf_rotmat[person],
                        'transf_transl': transf_transl[person],
                        'history_length': self.dataset.history_length,
                        'future_length': self.dataset.future_length,
                        'padding_mask': seq_data[person]['padding_mask'],
                    }
                if self.dataset.use_indi_text:
                    data_batch[person]['texts'] = texts[0]
                if self.dataset.load_text_embedding and self.dataset.use_indi_text:
                    data_batch[person]['text_embedding'] = text_embedding.detach().cpu()
                    if self.dataset.text_sep:
                        data_batch[person]['text_mask'] = text_mask.detach().cpu()
            texts = [interaction_text[0]]
            data_batch['interaction'] = {
                'texts': texts, 
                'rel_pose_b2a': rel_pose['b2a'],
                'rel_pose_a2b': rel_pose['a2b'],
            }
            if self.dataset.load_text_embedding:
                text_embedding = torch.stack([self.dataset.text_embedding_dict['interaction'][text] for text in texts], dim=0)  # [1, 512]
                text_mask = torch.stack([self.dataset.text_mask_dict['interaction'][text] for text in texts], dim=0) if self.dataset.text_sep else None
                data_batch['interaction']['text_embedding'] = text_embedding.detach().cpu()
                if self.dataset.text_sep:
                    data_batch['interaction']['text_mask'] = text_mask.detach().cpu()
            return data_batch
        else:
            for person in ['person1', 'person2']:
                motion_tensor_normalized = self.dataset.normalize(self.primitive_utility.dict_to_tensor(feature_dict[person]))     # [1, T, D]
                motion_tensor_normalized = motion_tensor_normalized.permute(0, 2, 1).unsqueeze(2)           # [1, D, 1, T]
                
                texts = [random.choice([itext['proc_label'] for itext in seq_data[f'frame_labels_{person}']])]
                # unseen_texts = [text for text in texts if text not in self.text_embedding_dict[person]]
                # if len(unseen_texts) > 0:
                #     self.update_text_embedding_dict(unseen_texts, person, text_sep=self.text_sep, max_segs=self.max_segs)
                data_batch.append(
                    {
                        'texts': texts[0],
                        'gender': seq_data[person]['gender'],
                        'betas': seq_data[person]['betas'],
                        'motion_tensor_normalized': motion_tensor_normalized.squeeze(0), # [1, D, 1, T]
                        'transf_rotmat': transf_rotmat[person],
                        'transf_transl': transf_transl[person],
                        'history_length': self.dataset.history_length,
                        'future_length': self.dataset.future_length,
                        'padding_mask': seq_data[person]['padding_mask'],
                    }
                )
                if self.dataset.load_text_embedding:
                    text_embedding = torch.stack([self.dataset.text_embedding_dict[person][text] for text in texts], dim=0)  # [1, 512]
                    text_mask = torch.stack([self.dataset.text_mask_dict[person][text] for text in texts], dim=0) if self.dataset.text_sep else None
                    data_batch[-1]['text_embedding'] = text_embedding.detach().cpu()
                    data_batch[-1]['text_mask'] = text_mask.detach().cpu()
            choice = 0 if random.random() < 0.5 else 1
            return data_batch[choice]

    def generate_motion(self, batch_data, interaction_text, motion_lens, batch_size):
        """
        Generate motion for a given batch of data.
        :param batch_data: A dictionary containing the input data for the model.
        :return: A dictionary containing the generated motion and other relevant information.
        """
        device = self.device
        future_length = self.dataset.future_length
        history_length = self.dataset.history_length
        primitive_length = history_length + future_length
        sample_fn = self.diffusion.p_sample_loop if self.diffusion_args.respacing == '' else self.diffusion.ddim_sample_loop
        
        batch_data = self.preprocess(batch_data, interaction_text)

        lengths = motion_lens + 1
        
        def expand_data(data, batch_size):
            for key in data.keys():
                if key in ['texts', 'gender']:
                    data[key] *= batch_size
                elif key in ['text_embedding', 'text_mask', 'betas', 'motion_tensor_normalized', 'transf_rotmat', 'transf_transl', 'padding_mask', 'rel_pose_b2a', 'rel_pose_a2b']:
                    data[key] = data[key].repeat(batch_size, *([1] * (data[key].dim() - 1)))
                else:
                    continue
        
        if self.data_args.interaction:
            for person in ['person1', 'person2', 'interaction']:
                expand_data(batch_data[person], batch_size)
        else:
            expand_data(batch_data, batch_size)
        
        if self.data_args.interaction:
            input_motions, model_kwargs, gender, betas, pelvis_delta, motion_tensor, history_motion_gt = {}, {}, {}, {}, {}, {}, {}
            for person in ['person1', 'person2']:
                input_motions[person] = batch_data[person]['motion_tensor_normalized']
                model_kwargs[person] = {'y': batch_data[person]}
                del model_kwargs[person]['y']['motion_tensor_normalized']
                gender[person] = model_kwargs[person]['y']['gender']
                betas[person] = model_kwargs[person]['y']['betas'][:, :primitive_length, :10].to(device) # [B, H+F, 10]
                pelvis_delta[person] = self.primitive_utility.calc_calibrate_offset({
                    'betas': betas[person][:, 0, :],
                    'gender': gender[person],
                })
                input_motions[person] = input_motions[person].to(device)  # [B, D, 1, T]
                motion_tensor[person] = input_motions[person].squeeze(2).permute(0, 2, 1)  # [B, T, D]
                history_motion_gt[person] = motion_tensor[person][:, :history_length, :]  # [B, H, D]
            interaction = batch_data['interaction']
        else:
            input_motions, model_kwargs = batch_data['motion_tensor_normalized'], {'y': batch_data}
            del model_kwargs['y']['motion_tensor_normalized']
            gender = model_kwargs['y']['gender']
            betas = model_kwargs['y']['betas'][:, :primitive_length, :].to(device)  # [B, H+F, 10]
            pelvis_delta = self.primitive_utility.calc_calibrate_offset({
                'betas': betas[:, 0, :],
                'gender': gender,
            })
            input_motions = input_motions.to(device)  # [B, D, 1, T]
            motion_tensor = input_motions.squeeze(2).permute(0, 2, 1)  # [B, T, D]
            history_motion_gt = motion_tensor[:, :history_length, :]  # [B, H, D]

        motion_sequences = None if not self.data_args.interaction else {'person1': None, 'person2': None}
        history_motion = history_motion_gt
        transf_rotmat = {
            'person1': batch_data['person1']['transf_rotmat'].to(device),
            'person2': batch_data['person2']['transf_rotmat'].to(device),
        } if self.data_args.interaction else batch_data['transf_rotmat'].to(device)
        transf_transl = {
            'person1': batch_data['person1']['transf_transl'].to(device),
            'person2': batch_data['person2']['transf_transl'].to(device),
        } if self.data_args.interaction else batch_data['transf_transl'].to(device)
        
        if self.args.fix_floor:
            if self.data_args.interaction:
                motion_dict = {}
                for person in ['person1', 'person2']:
                    motion_dict[person] = self.primitive_utility.tensor_to_dict(self.dataset.denormalize(history_motion_gt[person]))
                    joints = motion_dict[person]['joints'].reshape(batch_size, history_length, 22, 3)  # [B, T, 22, 3]   
                    init_floor_height = joints[:, 0, :, 2].amin(dim=-1)  # [B]
                    transf_transl[person][:, :, 2] = -init_floor_height.unsqueeze(-1)
            else:
                motion_dict = self.primitive_utility.tensor_to_dict(self.dataset.denormalize(history_motion_gt))
                joints = motion_dict['joints'].reshape(batch_size, history_length, 22, 3)  # [B, T, 22, 3]
                init_floor_height = joints[:, 0, :, 2].amin(dim=-1)  # [B]
                transf_transl[:, :, 2] = -init_floor_height.unsqueeze(-1)

        num_primitives = int(np.ceil((lengths-history_length) / future_length))
        guidance_param = torch.ones(batch_size, *self.denoiser_args.model_args.noise_shape).to(device=self.device) * self.args.guidance_param
        if self.denoiser_args.use_pre_latent:
            pre_latent = [] if not self.data_args.interaction else {'person1': [], 'person2': []}
            pre_transf_rotmat_abs = [] if not self.data_args.interaction else {'person1': [], 'person2': []}
            pre_transf_transl_abs = [] if not self.data_args.interaction else {'person1': [], 'person2': []}

        for primitive_id in range(num_primitives):
            valid_length = min(future_length, lengths - (primitive_id*future_length+history_length))
            if self.data_args.interaction:
                if self.denoiser_args.merge_his_relpose:
                    history_motion_rel = {}
                    for person in ['person1', 'person2']:
                        history_motion_rel[person] = copy.deepcopy(history_motion[person])
                        his_motion_rel_denormalized = self.dataset.denormalize(history_motion_rel[person])
                        his_motion_rel = self.primitive_utility.relative_transform_feature_tensor(
                            his_motion_rel_denormalized.to(device),
                            transforms.rotation_6d_to_matrix(interaction['rel_pose_a2b' if person == 'person1' else 'rel_pose_b2a'][:, :6]).to(device),
                            interaction['rel_pose_a2b' if person == 'person1' else 'rel_pose_b2a'][:, 6:9].unsqueeze(1).to(device),
                            batch_data[person]['gender'],
                            batch_data[person]['betas'][:, 0].to(device),
                        )
                        history_motion_rel[person] = self.dataset.normalize(his_motion_rel)  # [B, H, D]

                latent_pred, future_motion_pred, future_frames, all_frames, valid_future_frames, new_history_frames = {}, {}, {}, {}, {}, {}
                for person in ['person1', 'person2']:
                    y = {
                        'history_motion_normalized': history_motion[person],
                        'history_motion_normalized_b': history_motion['person2' if person == 'person1' else 'person1'],
                        'text_inter': batch_data['interaction']['texts'],
                        'rel_pose': interaction['rel_pose_'+'b2a' if person == 'person1' else 'rel_pose_'+'a2b'].to(device),
                        'scale': guidance_param,
                    }
                    if self.denoiser_args.load_text_embedding:
                        if self.denoiser_args.use_indi_text:
                            y['text_embedding'] = batch_data[person]['text_embedding'].to(device)
                        y['text_embedding_inter'] = batch_data['interaction']['text_embedding'].to(device)
                    else:
                        if self.text_encoder_version == 'v1':
                            if self.denoiser_args.use_indi_text:
                                y['text_embedding'] = self.text_encoder(batch_data[person]['texts']).to(device)
                            y['text_embedding_inter'] = self.text_encoder(batch_data['interaction']['texts']).to(device)
                        elif self.text_encoder_version in ['v2', 'v3']:
                            if self.denoiser_args.use_indi_text:
                                y['text_embedding'], batch_data[person]['text_mask'] = self.text_encoder(batch_data[person]['texts'])
                                y['text_embedding'] = y['text_embedding'].to(device)
                            y['text_embedding_inter'], batch_data['interaction']['text_mask'] = self.text_encoder(batch_data['interaction']['texts'])
                            y['text_embedding_inter'] = y['text_embedding_inter'].to(device)
                    if self.denoiser_args.merge_his_relpose:
                        y['history_motion_normalized_b'] = history_motion_rel['person2' if person == 'person1' else 'person1']
                    if self.denoiser_args.text_sep:
                        if self.denoiser_args.use_indi_text:
                            y['text_mask'] = batch_data[person]['text_mask'].to(device)
                        y['text_mask_inter'] = batch_data['interaction']['text_mask'].to(device)
                    if self.denoiser_args.use_pre_latent:
                        y['pre_latent'] = torch.cat(pre_latent[person], dim=1) if len(pre_latent[person])!=0 else None
                        if len(pre_transf_rotmat_abs[person])==0:
                            y['pre_reltrans'] = None
                        else:
                            y['pre_reltrans'] = []
                            for transf_rotmat_abs, transf_transl_abs in zip(pre_transf_rotmat_abs[person], pre_transf_transl_abs[person]):
                                rel_rotmat, rel_transl = self.primitive_utility.compute_rel_transform_B_in_A(
                                    transf_rotmat[person], transf_transl[person], transf_rotmat_abs, transf_transl_abs)
                                y['pre_reltrans'].append(torch.cat([transforms.matrix_to_rotation_6d(rel_rotmat), rel_transl.squeeze(1)], dim=-1).unsqueeze(1))
                            y['pre_reltrans'] = torch.cat(y['pre_reltrans'], dim=1)  # [B, num_primitive, 6+3]

                    x_start_pred = sample_fn(
                        self.denoiser_model,
                        (batch_size, *self.denoiser_args.model_args.noise_shape),
                        clip_denoised=False,
                        model_kwargs={'y': y},
                        skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                        init_image=None,
                        progress=False,
                        dump_steps=None,
                        noise=torch.zeros_like(guidance_param) if self.args.zero_noise else None,
                        const_noise=False,
                    )  # [B, T=1, D]
                    latent_pred[person] = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
                    future_motion_pred[person] = self.vae_model.decode(latent_pred[person], history_motion[person], nfuture=future_length,
                                                                scale_latent=self.denoiser_args.rescale_latent)
                    future_frames[person] = self.dataset.denormalize(future_motion_pred[person])
                    future_start, future_end = 0, valid_length
                    valid_future_frames[person] = future_frames[person][:, future_start:future_end, :]  # ignore the initial standing seed
                    all_frames[person] = torch.cat([self.dataset.denormalize(history_motion[person]), valid_future_frames[person]], dim=1)
                    new_history_end = history_length + valid_length
                    new_history_start = new_history_end - history_length
                    new_history_frames[person] = all_frames[person][:, new_history_start:new_history_end, :]
                    
                    if self.denoiser_args.use_pre_latent:
                        """store pre latent and global transform"""
                        pre_latent[person].append(latent_pred[person].permute(1, 0, 2))
                        pre_transf_rotmat_abs[person].append(transf_rotmat[person])
                        pre_transf_transl_abs[person].append(transf_transl[person])
                    
                    """transform primitive to world coordinate, prepare for serialization"""
                    if motion_sequences[person] is None:
                        all_feature_dict = self.primitive_utility.tensor_to_dict(all_frames[person])
                        all_feature_dict.update(
                            {
                                'transf_rotmat': transf_rotmat[person],
                                'transf_transl': transf_transl[person],
                                'gender': gender[person],
                                'betas': betas[person][:, :history_length+future_end, :],
                                'pelvis_delta': pelvis_delta[person],
                            }
                        )
                        all_primitive_dict = self.primitive_utility.feature_dict_to_smpl_dict(all_feature_dict)
                        all_primitive_dict = self.primitive_utility.transform_primitive_to_world(all_primitive_dict)
                        motion_sequences[person] = all_primitive_dict
                    else:
                        future_feature_dict = self.primitive_utility.tensor_to_dict(valid_future_frames[person])
                        future_feature_dict.update(
                            {
                                'transf_rotmat': transf_rotmat[person],
                                'transf_transl': transf_transl[person],
                                'gender': gender[person],
                                'betas': betas[person][:, future_start:future_end, :],
                                'pelvis_delta': pelvis_delta[person],
                            }
                        )
                        future_primitive_dict = self.primitive_utility.feature_dict_to_smpl_dict(future_feature_dict)
                        future_primitive_dict = self.primitive_utility.transform_primitive_to_world(future_primitive_dict)
                        for key in motion_sequences[person].keys():
                            if key in ['transl', 'global_orient', 'body_pose', 'betas', 'joints']:
                                motion_sequences[person][key] = torch.cat([motion_sequences[person][key], future_primitive_dict[key]], dim=1)  # [B, T, ...]

                    """update history motion seed, update global transform"""
                    if primitive_id < num_primitives - 1:  # not the last primitive
                        history_feature_dict = self.primitive_utility.tensor_to_dict(new_history_frames[person])
                        history_feature_dict.update(
                            {
                                'transf_rotmat': transf_rotmat[person],
                                'transf_transl': transf_transl[person],
                                'gender': gender[person],
                                'betas': betas[person][:, new_history_start:new_history_end, :],
                                'pelvis_delta': pelvis_delta[person],
                            }
                        )
                        if self.args.fix_floor and primitive_id == num_primitives - 1:  # fix the first frame feet of each segment to be on floor
                            foot_height = history_feature_dict['joints'].reshape(-1, history_length, 22, 3)[:, 0, FOOT_JOINTS_IDX, 2].amin(dim=-1)  # [B]
                            foot_height_world = foot_height + history_feature_dict['transf_transl'][:, 0, 2]  # [B]
                            history_feature_dict['transf_transl'][:, 0, 2] -= foot_height_world
                        canonicalized_history_primitive_dict, blended_feature_dict = self.primitive_utility.get_blended_feature(
                            history_feature_dict, use_predicted_joints=self.args.use_predicted_joints)
                        transf_rotmat[person], transf_transl[person] = canonicalized_history_primitive_dict['transf_rotmat'], \
                        canonicalized_history_primitive_dict['transf_transl']
                        history_motion[person] = self.primitive_utility.dict_to_tensor(blended_feature_dict)
                        history_motion[person] = self.dataset.normalize(history_motion[person])  # [B, T, D]
                if primitive_id < num_primitives - 1:
                    rel_rotmat, rel_transl, rel_pose = {}, {}, {}
                    rel_rotmat['b2a'], rel_transl['b2a'] = self.primitive_utility.compute_rel_transform_B_in_A(
                        transf_rotmat['person1'], transf_transl['person1'],transf_rotmat['person2'], transf_transl['person2'])
                    rel_rotmat['a2b'], rel_transl['a2b'] = self.primitive_utility.compute_rel_transform_B_in_A(
                        transf_rotmat['person2'], transf_transl['person2'],transf_rotmat['person1'], transf_transl['person1'])
                    for rel in ['b2a', 'a2b']:
                        rel_rotmat[rel] = transforms.matrix_to_rotation_6d(rel_rotmat[rel])
                        rel_pose[rel] = torch.cat([rel_rotmat[rel], rel_transl[rel].squeeze(1)], dim=-1)  # [B*num_mp, 6+3]
                    interaction['rel_pose_b2a'] = rel_pose['b2a']
                    interaction['rel_pose_a2b'] = rel_pose['a2b']
            else:
                y = {
                    'history_motion_normalized': history_motion,
                    'scale': guidance_param,
                }
                if self.denoiser_args.load_text_embedding:
                    y['text_embedding'] = batch_data['text_embedding'].to(device)
                else:
                    if self.text_encoder_version == 'v1':
                        y['text_embedding'] = self.text_encoder(batch_data['texts']).to(device)
                    elif self.text_encoder_version in ['v2', 'v3']:
                        y['text_embedding'], batch_data['text_mask'] = self.text_encoder(batch_data['texts'])
                        y['text_embedding'] = y['text_embedding'].to(device)
                if self.denoiser_args.text_sep:
                    y['text_mask'] = batch_data['text_mask'].to(device)

                x_start_pred = sample_fn(
                    self.denoiser_model,
                    (batch_size, *self.denoiser_args.model_args.noise_shape),
                    clip_denoised=False,
                    model_kwargs={'y': y},
                    skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                    init_image=None,
                    progress=False,
                    dump_steps=None,
                    noise=torch.zeros_like(guidance_param) if self.args.zero_noise else None,
                    const_noise=False,
                )  # [B, T=1, D]
                latent_pred = x_start_pred.permute(1, 0, 2)  # [T=1, B, D]
                future_motion_pred = self.vae_model.decode(latent_pred, history_motion, nfuture=future_length,
                                                    scale_latent=self.denoiser_args.rescale_latent)  # [B, F, D], normalized
                future_frames = self.dataset.denormalize(future_motion_pred)
                all_frames = torch.cat([self.dataset.denormalize(history_motion), future_frames], dim=1)
                
                future_start, future_end = 0, valid_length
                valid_future_frames = future_frames[:, future_start:future_end, :]  # ignore the initial standing seed
                new_history_end = history_length + valid_length
                new_history_start = new_history_end - history_length
                new_history_frames = all_frames[:, new_history_start:new_history_end, :]
                
                """transform primitive to world coordinate, prepare for serialization"""
                future_feature_dict = self.primitive_utility.tensor_to_dict(valid_future_frames)
                future_feature_dict.update(
                    {
                        'transf_rotmat': transf_rotmat,
                        'transf_transl': transf_transl,
                        'gender': gender,
                        'betas': betas[:, future_start:future_end, :],
                        'pelvis_delta': pelvis_delta,
                    }
                )
                future_primitive_dict = self.primitive_utility.feature_dict_to_smpl_dict(future_feature_dict)
                future_primitive_dict = self.primitive_utility.transform_primitive_to_world(future_primitive_dict)
                if motion_sequences is None:
                    motion_sequences = future_primitive_dict
                else:
                    for key in motion_sequences.keys():
                        if key in ['transl', 'global_orient', 'body_pose', 'betas', 'joints']:
                            motion_sequences[key] = torch.cat([motion_sequences[key], future_primitive_dict[key]], dim=1)  # [B, T, ...]

                """update history motion seed, update global transform"""
                history_feature_dict = self.primitive_utility.tensor_to_dict(new_history_frames)
                history_feature_dict.update(
                    {
                        'transf_rotmat': transf_rotmat,
                        'transf_transl': transf_transl,
                        'gender': gender,
                        'betas': betas[:, new_history_start:new_history_end, :],
                        'pelvis_delta': pelvis_delta,
                    }
                )
                if self.args.fix_floor and primitive_id == num_primitives - 1:  # fix the first frame feet of each segment to be on floor
                    foot_height = history_feature_dict['joints'].reshape(-1, history_length, 22, 3)[:, 0, FOOT_JOINTS_IDX, 2].amin(dim=-1)  # [B]
                    foot_height_world = foot_height + history_feature_dict['transf_transl'][:, 0, 2]  # [B]
                    history_feature_dict['transf_transl'][:, 0, 2] -= foot_height_world
                canonicalized_history_primitive_dict, blended_feature_dict = self.primitive_utility.get_blended_feature(
                    history_feature_dict, use_predicted_joints=self.args.use_predicted_joints)
                transf_rotmat, transf_transl = canonicalized_history_primitive_dict['transf_rotmat'], \
                canonicalized_history_primitive_dict['transf_transl']
                history_motion = self.primitive_utility.dict_to_tensor(blended_feature_dict)
                history_motion = self.dataset.normalize(history_motion)  # [B, T, D]
        def add_text(motion_sequences, texts):
            motion_sequences['texts'] = texts
            
        if self.data_args.interaction:
            for person in ['person1', 'person2']:
                add_text(motion_sequences[person], batch_data['interaction']['texts'])
        else:
            add_text(motion_sequences, batch_data['texts'])
        
        # padding
        def pad_motion(motion_sequences, padding_length):
            for key in motion_sequences.keys():
                if key in ['betas', 'transl', 'joints', 'global_orient', 'body_pose']:
                    last_frame = motion_sequences[key][:, -1:]
                    padding = last_frame.repeat(1, padding_length, *([1] * (motion_sequences[key].dim() - 2)))
                    motion_sequences[key] = torch.cat([motion_sequences[key], padding], dim=1)
                else:
                    continue
            return motion_sequences
        
        if lengths < self.max_length + 1:
            padding_length = self.max_length + 1 - lengths
            if self.data_args.interaction:
                for person in ['person1', 'person2']:
                    motion_sequences[person] = pad_motion(motion_sequences[person], padding_length)
            else:
                motion_sequences = pad_motion(motion_sequences, padding_length)
        if motion_lens >= 300:
            print(motion_lens, motion_sequences['person1']['joints'].shape)
        
        # for person in ['person1', 'person2']:
        #     for key in motion_sequences[person].keys():
        #         if key in ['betas', 'transl', 'global_orient', 'body_pose', 'pelvis_delta', 'joints', 'transf_rotmat', 'transf_transl']:
        #             motion_sequences[person][key] = motion_sequences[person][key].squeeze(0)   
        
        if self.data_args.interaction:
            for person in ['person1', 'person2']:
                tensor_dict_to_device(motion_sequences[person], 'cpu')
        else:
            tensor_dict_to_device(motion_sequences, 'cpu')
        tensor_dict_to_device(motion_sequences, 'cpu')
        return motion_sequences
    
    def __len__(self):
        """Return how many generated-motion entries this dataset holds."""
        entries = self.generated_motions
        return len(entries)

    def __getitem__(self, item):
        """Return ``("generated", text, motion, motion_lens)`` for entry *item*.

        The entry is read from ``self.generated_motions[item]`` and must
        provide the keys 'motion', 'motion_lens' and 'text'.
        """
        entry = self.generated_motions[item]
        motion = entry['motion']
        lengths = entry['motion_lens']
        caption = entry['text']
        return ("generated", caption, motion, lengths)


class MMGeneratedDataset(Dataset):
    """Dataset view over the multimodality (MM) entries of a generated-motion dataset.

    Reads ``mm_num_repeats`` and ``mm_generated_motions`` from the wrapped
    dataset. Each item is ``("mm_generated", text, mm_motions, motion_lens)``,
    where ``motion_lens`` is the entry's length repeated ``mm_num_repeats``
    times, packed into a numpy array.
    """

    def __init__(self, motion_dataset):
        # Only these two attributes of the source dataset are needed.
        self.mm_num_repeats = motion_dataset.mm_num_repeats
        self.dataset = motion_dataset.mm_generated_motions

    def __len__(self):
        entries = self.dataset
        return len(entries)

    def __getitem__(self, item):
        entry = self.dataset[item]
        motions = entry['mm_motions']
        length = entry['motion_lens']
        caption = entry['text']
        # One length per repeated generation of the same prompt.
        repeated_lengths = np.array([length] * self.mm_num_repeats)
        return ("mm_generated", caption, motions, repeated_lengths)


def get_dataset_motion_loader(data_args, batch_size, device, split, mode, text_sep, cut_length=0, **kwargs):
    """Build the ground-truth InterHuman evaluation dataset and a DataLoader over it.

    Args:
        data_args: Config namespace. ``data_args.dataset`` selects the motion
            representation ('interhuman' or 'interhuman_d262'); its other
            fields (data_dir, cfg_path, enforce_gender, ...) are forwarded to
            ``InterHumanDatasetEvalV2``.
        batch_size: Batch size of the returned loader (shuffled, ``drop_last=True``).
        device, split, mode, text_sep, cut_length: Forwarded to the dataset.
        **kwargs: Optional ``clip_version`` (default ``'ViT-B/32'``),
            ``load_text_embedding`` (default ``False``) and ``use_indi_text``
            (default ``False``), all forwarded to the dataset.

    Returns:
        ``(dataloader, dataset)``.

    Raises:
        KeyError: If ``data_args.dataset`` is not a supported dataset name.
    """
    # Feature layout of each supported dataset: feature name -> channel count.
    motion_reprs = {
        'interhuman': {
            'transl': 3,
            'poses_6d': 22 * 6,
            'transl_delta': 3,
            'global_orient_delta_6d': 6,
            'joints': 22 * 3,
            'joints_delta': 22 * 3,
        },
        'interhuman_d262': {
            'joints': 22 * 3,
            'joints_delta': 22 * 3,
            'body_pose': 21 * 6,
            'feet_contact': 4,
        },
    }
    if data_args.dataset not in motion_reprs:
        # Name the offending value so misconfigurations are easy to spot.
        raise KeyError(f'Dataset not recognized: {data_args.dataset!r}')
    motion_repr = motion_reprs[data_args.dataset]

    print('Loading dataset %s ...' % data_args.dataset)
    # Imported lazily so the heavy dataset module loads only when needed.
    from data_loaders.HHI.data.dataset_interhuman import InterHumanDatasetEvalV2
    clip_version = kwargs.get('clip_version', 'ViT-B/32')
    load_text_embedding = kwargs.get('load_text_embedding', False)
    use_indi_text = kwargs.get('use_indi_text', False)
    dataset = InterHumanDatasetEvalV2(
        dataset_name=data_args.dataset,
        dataset_path=data_args.data_dir,
        cfg_path=data_args.cfg_path,
        split=split,
        device=device,
        enforce_gender=data_args.enforce_gender,
        enforce_zero_beta=data_args.enforce_zero_beta,
        body_type=data_args.body_type,
        mode=mode,
        text_sep=text_sep,
        min_length=data_args.min_length,
        max_length=data_args.max_length,
        motion_repr=motion_repr,
        padding=data_args.padding,
        cut_length=cut_length,
        clip_version=clip_version,
        load_text_embedding=load_text_embedding,
        use_indi_text=use_indi_text,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, drop_last=True, shuffle=True)

    print('Ground Truth Dataset Loading Completed!!!')
    return dataloader, dataset


def get_motion_loader(args, batch_size, model, denoiser_args, diffusion, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset, device, mm_num_samples, mm_num_repeats,**kwargs):
    """Build loaders over generated motions, generating and caching them first if needed.

    Generated motions are cached as ``.pt`` files in a subdirectory of
    ``args.load_dir``. The subdirectory is ``replications_<replication_times>``
    when ``replication_times > 1`` and ``replication_<replication>`` otherwise.
    Generation runs on all GPUs via ``launch_generation_on_multi_gpus`` when
    that directory is missing, or when its number of ``.pt`` files differs
    from ``len(ground_truth_dataset)``.

    Args:
        **kwargs: Optional ``replication`` (default 0), ``replication_times``
            (default 1) and ``text_encoder`` (default None).
        The remaining arguments are forwarded to the generation launcher
        and/or ``EvaluationDataset``.

    Returns:
        ``(motion_loader, mm_motion_loader)``: a shuffled ``drop_last`` loader
        over ``EvaluationDataset``, and a batch-size-1 loader over its
        ``MMGeneratedDataset`` view.
    """
    replication = kwargs.get('replication', 0)
    replication_times = kwargs.get('replication_times', 1)
    text_encoder = kwargs.get('text_encoder', None)

    # Only the cache directory name depends on the replication mode;
    # the check-and-generate step is identical for both.
    if replication_times > 1:
        cache_name = f"replications_{replication_times}"
    else:
        cache_name = f"replication_{replication}"
    load_dir = os.path.join(args.load_dir, cache_name)

    cache_complete = (
        os.path.exists(load_dir)
        and len([f for f in os.listdir(load_dir) if f.endswith('.pt')]) == len(ground_truth_dataset)
    )
    if not cache_complete:
        launch_generation_on_multi_gpus(
            args, model, denoiser_args, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset,
            mm_num_samples, mm_num_repeats, load_dir, text_encoder=text_encoder, replication_times=replication_times,
        )

    dataset = EvaluationDataset(args, model, denoiser_args, diffusion, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset, device, mm_num_samples, mm_num_repeats,
                                load_dir=load_dir, replication_times=replication_times, replication=replication)
    mm_dataset = MMGeneratedDataset(dataset)

    motion_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True, num_workers=0, shuffle=True)
    mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=0)

    print('Generated Dataset Loading Completed!!!')

    return motion_loader, mm_motion_loader

# def get_dataset_loader_mp_seq(dataset_path, batch_size, split='train'):
#     from data_loaders.humanml.data.dataset import PrimitiveSequenceDataset
#     dataset = PrimitiveSequenceDataset(dataset_path=dataset_path, split=split, load_data=True)
#
#     loader = DataLoader(
#         dataset, batch_size=batch_size, shuffle=True if split == 'train' else False,
#         num_workers=8, drop_last=True, collate_fn=mp_seq_collate
#     )
#
#     return loader

# if __name__ == "__main__":
#     # trian_loader = get_dataset_loader_mp(dataset_path='./data/mp_data/Canonicalized_h2_f8_num1_fps30/', batch_size=2, split='train')
#     # for i, batch in enumerate(trian_loader):
#     #     print(i)
#     #     print(batch[0], batch[1])
#     #     break

#     from data_loaders.humanml.data.dataset import PrimitiveSequenceDataset
#     dataset = PrimitiveSequenceDataset(dataset_path='./data/mp_data/Canonicalized_h2_f8_num1_fps30/',
#                                          split='train')
#     for _ in tqdm(range(10)):
#         batch = dataset.get_batch(batch_size=64)


def generate_worker(rank, args, all_sample_indices, 
                    model, denoiser_args, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset, 
                    mm_num_samples, mm_num_repeats, mm_sample_indices, save_dir, replication_times=1, text_encoder=None):
    """Per-GPU generation worker (entry point for ``mp.spawn``).

    Pins the process to CUDA device ``rank`` and walks the ground-truth
    dataset in order. For each sample id in this rank's shard
    (``all_sample_indices[rank]``) it writes ``{save_dir}/motion_<id>.pt``.

    Samples that already have a saved file are skipped, unless they are in
    ``mm_sample_indices``. MM samples are generated ``mm_num_repeats`` times;
    all others are generated ``replication_times`` times.
    ``mm_num_samples`` is accepted but not used here.
    """
    # Cap CPU threading per process and bind this worker to its GPU.
    torch.set_num_threads(4)
    torch.set_num_interop_threads(4)
    torch.cuda.set_device(rank)

    my_shard = all_sample_indices[rank]

    from mld.train_mld_new_v9 import create_gaussian_diffusion
    gaussian_diffusion = create_gaussian_diffusion(diffusion_args)
    generator = EvaluationDataset(
        args=args,
        denoiser_model=model.to(rank),
        denoiser_args=denoiser_args,
        diffusion=gaussian_diffusion,
        diffusion_args=diffusion_args,
        vae_model=vae_model.to(rank),
        vae_args=vae_args,
        data_args=data_args,
        ground_truth_dataset=ground_truth_dataset,
        device=torch.device(f"cuda:{rank}"),
        mm_num_samples=mm_num_samples,
        mm_num_repeats=mm_num_repeats,
        load_data=False,
        text_encoder=text_encoder,
    )

    gt_loader = DataLoader(ground_truth_dataset, batch_size=1, shuffle=False, num_workers=0)

    with torch.no_grad():
        for sample_index, interaction_text, batch_data, motion_lens in gt_loader:
            sample_id = sample_index[0]
            out_path = f"{save_dir}/motion_{sample_id}.pt"
            # Reuse cached non-MM results; MM samples are always regenerated.
            cached = os.path.exists(out_path) and sample_id not in mm_sample_indices
            if cached or sample_id not in my_shard:
                continue

            repeats = mm_num_repeats if sample_id in mm_sample_indices else replication_times
            motion_sequences = generator.generate_motion(batch_data, interaction_text, motion_lens, repeats)

            torch.save(motion_sequences, out_path)


import torch.multiprocessing as mp

def launch_generation_on_multi_gpus(args, model, denoiser_args, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset, 
                    mm_num_samples, mm_num_repeats, save_dir, **kwargs):
    """Generate motions for every ground-truth sample, sharded across all visible GPUs.

    Sample ids (element 0 of each dataset item) are shuffled. The first
    ``mm_num_samples`` of them are chosen as multimodality samples, and their
    sorted ids are written to ``save_dir/mm_indices.json``. The shuffled ids
    are then dealt round-robin to one ``generate_worker`` process per GPU.

    Args:
        save_dir: Output directory, created if missing.
        **kwargs: Optional ``replication_times`` (default 1) and
            ``text_encoder`` (default None), passed to each worker.

    Raises:
        RuntimeError: If no CUDA device is visible. Previously this surfaced
            as a ZeroDivisionError, after the whole dataset had been loaded.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError(
            "launch_generation_on_multi_gpus requires at least one CUDA device, but none are visible."
        )

    real_indices = [ground_truth_dataset[i][0] for i in range(len(ground_truth_dataset))]
    random.shuffle(real_indices)
    mm_sample_indices = set(real_indices[:mm_num_samples])

    replication_times = kwargs.get('replication_times', 1)
    text_encoder = kwargs.get('text_encoder', None)

    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, "mm_indices.json"), 'w') as f:
        json.dump(sorted(mm_sample_indices), f)

    # Round-robin sharding: the k-th shuffled id goes to GPU k % num_gpus.
    per_gpu_indices = [real_indices[rank::num_gpus] for rank in range(num_gpus)]

    mp.spawn(
        generate_worker,
        args=(args, per_gpu_indices, model, denoiser_args, diffusion_args, vae_model, vae_args, data_args, ground_truth_dataset,
              mm_num_samples, mm_num_repeats, mm_sample_indices, save_dir, replication_times, text_encoder),
        nprocs=num_gpus,
        join=True,
    )
