import torch
import enum
from data_loaders.humanml.networks.modules import *
from data_loaders.humanml.networks.trainers import CompTrainerV6
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
from utils import dist_util

import logging

# Configure logging.
# NOTE: this runs as a side effect of importing the module. It asks the root
# logger to write records of level DEBUG and above to 'debug.log' (a path
# relative to the current working directory), each line formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(filename='debug.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

class ModelVarType(enum.Enum):
    """
    Enumerates the ways the model's output variance can be determined.

    LEARNED_RANGE was added so that the model may predict a value lying
    between FIXED_SMALL and FIXED_LARGE, which makes its job easier.
    """

    # The explicit integers are the values enum.auto() would assign (1, 2, ...).
    # Code in this file compares members through ``.value``, so they must stay
    # stable.
    LEARNED = 1
    FIXED_SMALL = 2
    FIXED_LARGE = 3
    LEARNED_RANGE = 4

def build_models(opt):
    """Instantiate the networks used by the CompV6 text-to-motion pipeline.

    Every size and hyper-parameter is read from ``opt``. The motion-length
    estimator is additionally restored from
    ``<checkpoints_dir>/<dataset_name>/length_est_bigru/model/latest.tar``,
    moved to ``opt.device`` and put in eval mode; the other networks are
    returned freshly constructed.

    Returns:
        Tuple ``(text_encoder, seq_prior, seq_decoder, att_layer,
        movement_enc, movement_dec, len_estimator)``.

    Raises:
        Exception: if ``opt.text_enc_mod`` is anything other than ``'bigru'``.
    """
    # Only the bi-GRU text encoder is supported; reject anything else up front.
    if opt.text_enc_mod != 'bigru':
        raise Exception("Text Encoder Mode not Recognized!!!")

    text_encoder = TextEncoderBiGRU(
        word_size=opt.dim_word,
        pos_size=opt.dim_pos_ohot,
        hidden_size=opt.dim_text_hidden,
        device=opt.device,
    )
    # Text feature width is twice the hidden size
    # (presumably forward + backward GRU states -- TODO confirm).
    text_size = 2 * opt.dim_text_hidden

    seq_prior = TextDecoder(
        text_size=text_size,
        input_size=opt.dim_att_vec + opt.dim_movement_latent,
        output_size=opt.dim_z,
        hidden_size=opt.dim_pri_hidden,
        n_layers=opt.n_layers_pri,
    )

    seq_decoder = TextVAEDecoder(
        text_size=text_size,
        input_size=opt.dim_att_vec + opt.dim_z + opt.dim_movement_latent,
        output_size=opt.dim_movement_latent,
        hidden_size=opt.dim_dec_hidden,
        n_layers=opt.n_layers_dec,
    )

    att_layer = AttLayer(
        query_dim=opt.dim_pos_hidden,
        key_dim=text_size,
        value_dim=opt.dim_att_vec,
    )

    # The encoder receives 4 fewer input channels than the full pose dimension.
    movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
    movement_dec = MovementConvDecoder(opt.dim_movement_latent, opt.dim_movement_dec_hidden, opt.dim_pose)

    len_estimator = MotionLenEstimatorBiGRU(opt.dim_word, opt.dim_pos_ohot, 512, opt.num_classes)

    # Restore the pretrained length estimator; it is used for inference only.
    estimator_ckpt_path = pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_est_bigru', 'model', 'latest.tar')
    estimator_ckpt = torch.load(estimator_ckpt_path, map_location=opt.device)
    len_estimator.load_state_dict(estimator_ckpt['estimator'])
    len_estimator.to(opt.device)
    len_estimator.eval()

    return (
        text_encoder,
        seq_prior,
        seq_decoder,
        att_layer,
        movement_enc,
        movement_dec,
        len_estimator,
    )

class CompV6GeneratedDataset(Dataset):
    """Dataset of motions generated ahead of time by a ``CompTrainerV6`` model.

    Construction iterates once over ``dataset`` (batch size 1, shuffled) and
    generates one motion per item, stored in ``self.generated_motion``. For
    ``mm_num_samples`` randomly chosen iteration indices the generation is
    repeated ``mm_num_repeats`` times and the repeats are kept in
    ``self.mm_generated_motion``.
    """

    def __init__(self, opt, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
        """Load the trained model and generate every motion up front.

        Args:
            opt: options object; supplies model paths, device, ``unit_length``,
                ``dim_pose``, ``dataset_name`` and ``max_motion_length``.
            dataset: source dataset yielding
                ``(word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens)``.
            w_vectorizer: token -> ``(word_embedding, pos_one_hot)`` lookup
                used by ``__getitem__``.
            mm_num_samples: number of items that get repeated generations;
                must be smaller than ``len(dataset)``.
            mm_num_repeats: number of generations for each such item.
        """
        assert mm_num_samples < len(dataset)
        print(opt.model_dir)

        loader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        text_enc, seq_pri, seq_dec, att_layer, mov_enc, mov_dec, len_estimator = build_models(opt)
        trainer = CompTrainerV6(opt, text_enc, seq_pri, seq_dec, att_layer, mov_dec, mov_enc=mov_enc)
        epoch, _, _, schedule_len = trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))

        generated = []
        mm_generated = []
        # Sorted iteration indices at which repeated generation takes place.
        mm_idxs = np.sort(np.random.choice(len(dataset), mm_num_samples, replace=False))
        min_mov_length = 10 if opt.dataset_name == 't2m' else 6

        print('Loading model: Epoch %03d Schedule_len %03d' % (epoch, schedule_len))
        trainer.eval_mode()
        trainer.to(opt.device)
        with torch.no_grad():
            for i, batch in tqdm(enumerate(loader)):
                word_emb, pos_ohot, caption, cap_lens, _, _, tokens = batch
                tokens = tokens[0].split('_')
                word_emb = word_emb.detach().to(opt.device).float()
                pos_ohot = pos_ohot.detach().to(opt.device).float()

                # Distribution over motion lengths predicted from the text.
                length_logits = len_estimator(word_emb, pos_ohot, cap_lens)
                pred_dis = torch.softmax(length_logits, dim=-1).squeeze()

                # mm_idxs is sorted, so only the next pending index can match.
                mm_done = len(mm_generated)
                is_mm = bool((mm_done < mm_num_samples) and (i == mm_idxs[mm_done]))

                mm_motions = []
                for t in range(mm_num_repeats if is_mm else 1):
                    # Draw a length; redraw at most twice while it is too short.
                    mov_length = torch.multinomial(pred_dis, 1, replacement=True)
                    for _ in range(2):
                        if mov_length < min_mov_length:
                            mov_length = torch.multinomial(pred_dis, 1, replacement=True)

                    m_lens = mov_length * opt.unit_length
                    pred_motions, _, _ = trainer.generate(word_emb, pos_ohot, cap_lens, m_lens,
                                                          m_lens[0] // opt.unit_length, opt.dim_pose)

                    # The first generation is the regular sample for this item.
                    if t == 0:
                        generated.append({
                            'motion': pred_motions[0].cpu().numpy(),
                            'length': m_lens[0].item(),
                            'cap_len': cap_lens[0].item(),
                            'caption': caption[0],
                            'tokens': tokens,
                        })

                    if is_mm:
                        mm_motions.append({
                            'motion': pred_motions[0].cpu().numpy(),
                            'length': m_lens[0].item(),
                        })

                if is_mm:
                    mm_generated.append({
                        'caption': caption[0],
                        'tokens': tokens,
                        'cap_len': cap_lens[0].item(),
                        'mm_motions': mm_motions,
                    })

        self.generated_motion = generated
        self.mm_generated_motion = mm_generated
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        """Return the number of generated motions."""
        return len(self.generated_motion)

    def __getitem__(self, item):
        """Return one generated sample in the same layout as the source data.

        Returns:
            ``(word_embeddings, pos_one_hots, caption, sent_len, motion,
            m_length, tokens_joined)``, where ``motion`` is zero-padded along
            axis 0 up to ``opt.max_motion_length`` when it is shorter, and
            ``tokens_joined`` is the token list joined with ``'_'``.
        """
        entry = self.generated_motion[item]
        motion = entry['motion']
        m_length = entry['length']
        caption = entry['caption']
        tokens = entry['tokens']
        sent_len = entry['cap_len']

        # Look every token up once, then stack embeddings and POS one-hots.
        looked_up = [self.w_vectorizer[token] for token in tokens]
        pos_one_hots = np.concatenate([pos_oh[None, :] for _, pos_oh in looked_up], axis=0)
        word_embeddings = np.concatenate([word_emb[None, :] for word_emb, _ in looked_up], axis=0)

        max_len = self.opt.max_motion_length
        if m_length < max_len:
            padding = np.zeros((max_len - m_length, motion.shape[1]))
            motion = np.concatenate([motion, padding], axis=0)

        return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)


class CompMDMPGeneratedDataset(Dataset):
    """Dataset of motions sampled ahead of time from a diffusion model.

    Construction iterates over ``dataloader``, conditions the model on each
    batch (the input motion is passed as ``motion_embed`` with a boolean mask
    that is True before ``start_idx`` on the last axis) and stores one sample
    per batch element in ``self.generated_motion``. For randomly chosen batch
    indices the sampling is repeated ``mm_num_repeats`` times and the repeats
    are stored in ``self.mm_generated_motion``.

    When the diffusion process uses a learned variance (``self.lv``), the
    sampler is expected to return ``(sample, log_variance)`` and the log
    variance is stored next to every motion.
    """

    def __init__(self, model, diffusion, dataloader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, start_idx, scale=1.):
        """Sample every motion up front.

        Args:
            model: network handed to the sampling loop; switched to eval mode.
            diffusion: object exposing ``model_var_type``, ``p_sample_loop``
                and ``ddim_sample_loop``.
            dataloader: yields ``(motion, model_kwargs)`` batches; its dataset
                must expose ``w_vectorizer``.
            mm_num_samples: when > 0, ``mm_num_samples // batch_size + 1``
                batches are picked for repeated sampling; must be smaller
                than ``len(dataloader.dataset)``.
            mm_num_repeats: number of samples drawn for each picked batch.
            max_motion_length: stored on the instance; not otherwise used here.
            num_samples_limit: stop once at least this many samples exist
                (``None`` means consume the whole dataloader).
            start_idx: positions from this index on (last axis) are masked out
                of the motion conditioning.
            scale: classifier-free-guidance scale; only added to the batch
                when different from 1.
        """
        self.dataloader = dataloader
        # Compared through ``.value`` rather than by member identity, presumably
        # because ``diffusion.model_var_type`` belongs to a different enum class
        # than the ModelVarType of this module -- members of distinct enum
        # classes never compare equal.
        self.lv = diffusion.model_var_type.value in [ModelVarType.LEARNED.value, ModelVarType.LEARNED_RANGE.value]
        self.dataset = dataloader.dataset
        assert mm_num_samples < len(dataloader.dataset)
        use_ddim = False  # FIXME - hardcoded
        clip_denoised = False  # FIXME - hardcoded
        self.max_motion_length = max_motion_length
        sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop

        real_num_batches = len(dataloader)
        if num_samples_limit is not None:
            real_num_batches = num_samples_limit // dataloader.batch_size + 1
        logging.debug('real_num_batches: %d', real_num_batches)

        generated_motion = []
        mm_generated_motions = []
        if mm_num_samples > 0:
            mm_idxs = np.random.choice(real_num_batches, mm_num_samples // dataloader.batch_size + 1, replace=False)
            mm_idxs = np.sort(mm_idxs)
        else:
            mm_idxs = []
        logging.debug('mm_idxs: %s', mm_idxs)

        model.eval()

        with torch.no_grad():
            for i, (motion, model_kwargs) in tqdm(enumerate(dataloader)):
                if num_samples_limit is not None and len(generated_motion) >= num_samples_limit:
                    break

                motion = motion.to(dist_util.dev())
                # Use the real size of this batch: the final batch can be
                # shorter than dataloader.batch_size.
                batch_len = motion.shape[0]
                cond = model_kwargs['y']
                tokens = [t.split('_') for t in cond['tokens']]

                # add CFG scale to batch
                if scale != 1.:
                    cond['scale'] = torch.ones(batch_len, device=dist_util.dev()) * scale

                # Condition on the input motion; only positions before
                # start_idx (last axis) are marked as observed.
                cond['motion_embed'] = motion
                embed_mask = torch.ones_like(motion, dtype=torch.bool, device=motion.device)
                embed_mask[:, :, :, start_idx:] = False
                cond['motion_embed_mask'] = embed_mask

                is_mm = i in mm_idxs
                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
                for t in range(repeat_times):
                    output = sample_fn(
                        model=model,
                        shape=motion.shape,
                        clip_denoised=clip_denoised,
                        model_kwargs=model_kwargs,
                        skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
                        init_image=None,
                        # Kept as before: the progress bar is shown only on
                        # the learned-variance path.
                        progress=self.lv,
                        dump_steps=None,
                        noise=None,
                        const_noise=False,
                        # when experimenting with guidance_scale we want to neutralize the effect of noise on generation
                    )
                    # With a learned variance the sampler returns a pair.
                    if self.lv:
                        sample, log_variance = output
                    else:
                        sample, log_variance = output, None

                    # The first draw is the regular sample for each batch element.
                    if t == 0:
                        records = self._motion_records(motion, sample, log_variance, cond['lengths'])
                        for bs_i, record in enumerate(records):
                            record['caption'] = cond['text'][bs_i]
                            record['tokens'] = tokens[bs_i]
                            record['cap_len'] = tokens[bs_i].index('eos/OTHER') + 1
                        generated_motion += records

                    if is_mm:
                        # Keyed on self.lv (value-based check above), so the
                        # log variance is not lost when the diffusion object
                        # uses its own enum class.
                        mm_motions += self._motion_records(motion, sample, log_variance, cond['lengths'])

                if is_mm:
                    # mm_motions is laid out repeat-major, so a stride of
                    # batch_len collects all repeats of one batch element.
                    # NOTE(review): cap_len here is the full token count, while
                    # generated_motion uses index('eos/OTHER') + 1 -- confirm
                    # which one consumers expect.
                    mm_generated_motions += [{
                        'caption': cond['text'][bs_i],
                        'tokens': tokens[bs_i],
                        'cap_len': len(tokens[bs_i]),
                        'mm_motions': mm_motions[bs_i::batch_len],
                    } for bs_i in range(batch_len)]

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.w_vectorizer = dataloader.dataset.w_vectorizer

    @staticmethod
    def _motion_records(motion, sample, log_variance, lengths):
        """Turn one batch of tensors into a list of per-sample numpy dicts.

        Each element is squeezed and its two remaining axes are swapped before
        conversion. ``'log_variance'`` is included only when ``log_variance``
        is not ``None``.
        """
        records = []
        for bs_i in range(motion.shape[0]):
            record = {
                'input_motion': motion[bs_i].squeeze().permute(1, 0).cpu().numpy(),
                'motion': sample[bs_i].squeeze().permute(1, 0).cpu().numpy(),
            }
            if log_variance is not None:
                record['log_variance'] = log_variance[bs_i].squeeze().permute(1, 0).cpu().numpy()
            record['length'] = lengths[bs_i].cpu().numpy()
            records.append(record)
        return records

    def __len__(self):
        """Return the number of generated motions (also logged at DEBUG level)."""
        length = len(self.generated_motion)
        logging.debug('len of generated motion: %d', length)
        return length

    def __getitem__(self, item):
        """Return one generated sample.

        Returns:
            ``(word_embeddings, pos_one_hots, caption, sent_len, input_motion,
            motion, log_variance, m_length, tokens_joined)``. Without a
            learned variance, ``log_variance`` is an all-zero array shaped
            like ``input_motion`` so the tuple layout stays the same.
        """
        data = self.generated_motion[item]
        input_motion, motion = data['input_motion'], data['motion']
        m_length, caption, tokens = data['length'], data['caption'], data['tokens']
        sent_len = data['cap_len']

        if self.dataset.mode == 'eval':
            # Undo the dataset normalisation. An earlier version also
            # re-normalised with mean_for_eval / std_for_eval (the T2M
            # evaluator convention); that step is no longer applied, and
            # log_variance is left untouched.
            motion = self.dataset.t2m_dataset.inv_transform(motion)
            input_motion = self.dataset.t2m_dataset.inv_transform(input_motion)

        if self.lv:
            log_variance = data['log_variance']
        else:
            # Placeholder that keeps the returned tuple layout uniform.
            log_variance = np.zeros_like(input_motion)

        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        return word_embeddings, pos_one_hots, caption, sent_len, input_motion, motion, log_variance, m_length, '_'.join(tokens)