from multiprocessing.pool import Pool

import matplotlib

from utils.pl_utils import data_loader

matplotlib.use('Agg')
from utils.tts_utils import GeneralDenoiser
import glob
import os
import re
import numpy as np
from tqdm import tqdm
import torch.distributed as dist

from modules import speech_transducer
from tasks.base_task import BaseTask, BaseDataset
from utils.hparams import hparams
from utils.indexed_datasets import IndexedDataset
from utils.text_encoder import TokenTextEncoder
import json

import matplotlib.pyplot as plt
import torch
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import utils
import logging
from utils import audio
from utils.pwg_decode_from_mel import generate_wavegan, load_pwg_model

from utils.world_utils import normalize_mel, denormalize_mel
from losses.rnnt_loss import RNNTLossTTS
from scipy.io import wavfile


class SpeechTransducerDataset(BaseDataset):
    """Dataset yielding phoneme / mel-spectrogram pairs plus alignment maps.

    Each item is built from a pre-processed record containing the keys
    ``'phone'``, ``'mel'``, ``'dur'`` and ``'txt'``.  Records are read either
    from one ``.npy`` file per utterance or from an ``IndexedDataset``,
    depending on ``hparams['indexed_ds']``.

    Side effect: the constructor writes frame-level mel statistics
    (``'mel_mean'`` and ``'mel_std'``) into the ``hparams`` dict it receives.

    NOTE(review): ``self.data_dir``, ``self.prefix`` and ``self.hparams`` are
    used but never assigned here; presumably ``BaseDataset.__init__`` sets
    them -- confirm.
    """

    def __init__(self, data_dir, phone_encoder, prefix, hparams, shuffle=False):
        super().__init__(data_dir, prefix, hparams, shuffle)
        self.phone_encoder = phone_encoder
        self.data = None
        self.idx2key = np.load(f'{self.data_dir}/{self.prefix}_all_keys.npy')
        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
        self.train_idx2key = np.load(f'{self.data_dir}/train_all_keys.npy')
        self.use_indexed_ds = hparams['indexed_ds']
        # Opened lazily by _get_item on first access.
        self.indexed_bs = None

        # Frame-level mel statistics, always taken from the *training* mels
        # regardless of `prefix`, and published through `hparams`.
        # NOTE(review): allow_pickle=True unpickles objects while loading and
        # must only be used on trusted, locally generated files.
        train_mels = np.load(f'{self.data_dir}/train_mels.npy', allow_pickle=True)
        train_frames = np.concatenate(train_mels, 0)
        hparams['mel_mean'] = self.mel_mean = np.mean(train_frames, 0)
        hparams['mel_std'] = self.mel_std = np.std(train_frames, 0)

    def _get_item(self, index):
        """Return the raw record for ``index`` from the configured storage backend."""
        if self.use_indexed_ds:
            if self.indexed_bs is None:
                self.indexed_bs = IndexedDataset(f'{self.data_dir}/{self.prefix}')
            return self.indexed_bs[index]
        key = self.idx2key[index]
        # NOTE(review): allow_pickle=True -- trusted preprocessing output only.
        return np.load(f'{self.data_dir}/{self.prefix}/{key}.npy', allow_pickle=True).item()

    def _make_pad_mask(self, target_lens, maxlen=None):
        """Build a boolean mask that is True at padded positions.

        Examples:
            lengths = [5, 3, 2]
            make_pad_mask(lengths)
            masks = [[0, 0, 0, 0 ,0],
                     [0, 0, 0, 1, 1],
                     [0, 0, 1, 1, 1]]

        Args:
            target_lens: list of lengths, or any object with ``tolist()``
                (e.g. a 1-D tensor).
            maxlen: mask width; defaults to ``max(target_lens)``.

        Returns:
            Bool tensor of shape (len(target_lens), maxlen) where entry
            ``[b, p]`` is True iff ``p >= target_lens[b]``.
        """
        if not isinstance(target_lens, list):
            target_lens = target_lens.tolist()
        if maxlen is None:
            maxlen = int(max(target_lens))
        positions = torch.arange(0, maxlen, dtype=torch.int64).unsqueeze(0)  # (1, maxlen)
        lengths = torch.tensor(target_lens, dtype=torch.int64).unsqueeze(-1)  # (B, 1)
        # Broadcasting yields the (B, maxlen) comparison without an explicit expand.
        return positions >= lengths

    def _make_attention_map(self, dur):
        """Turn per-phoneme durations into alignment and shift maps.

        Args:
            dur (tensor): (L,), frames
        Returns:
            alpha: attention map, (L, sum(dur)).  Row ``i`` is 1 from the
                 first frame of phoneme ``i`` up to and including the frame
                 at its cumulative end, clipped to the map width.
            phi: shift phoneme labels, (L, sum(dur))
                 shift to next phoneme if phi[i] is 1.
            attention_map = alpha * (1-phi)
        """
        dur_cum_b = dur.cumsum(dim=0)               # exclusive end frame of each phoneme
        dur_cum_a = F.pad(dur_cum_b, (1, 0))[:-1]   # start frame of each phoneme
        max_len = int(dur_cum_b.max())
        # "+1" keeps the boundary frame in the row so that a shift can be marked on it.
        mask_b = (~self._make_pad_mask(dur_cum_b + 1, maxlen=max_len)).float()
        mask_a = (~self._make_pad_mask(dur_cum_a, maxlen=max_len)).float()
        alpha = mask_b - mask_a
        # A shift is marked on the last active frame of each row: the position
        # where alpha is 1 and its right neighbour (zero beyond the edge) is 0.
        phi = (F.pad(alpha[:, 1:], (0, 1)) - alpha) * alpha
        phi = phi.bool().float()
        # There is nothing to shift to after the final phoneme's final frame.
        phi[-1, -1] = 0
        return alpha, phi

    def __getitem__(self, index):
        """Build one training sample; frame axes are truncated to ``hparams['max_frames']``."""
        hparams = self.hparams
        item = self._get_item(index)
        max_frames = hparams['max_frames']

        # input / output
        phone = torch.LongTensor(item['phone'])
        spec = torch.tensor(normalize_mel(item['mel'], hparams))
        dur = torch.tensor(item['dur'])
        alpha, phi = self._make_attention_map(dur)

        return {
            "id": index,
            "utt_id": self.idx2key[index],
            "text": item['txt'],
            "source": phone,
            "target": spec[:max_frames],
            "alpha": alpha[:, :max_frames],
            "phi": phi[:, :max_frames],
        }

    def collater(self, samples):
        """Pad a list of samples into a batch sorted by descending target length."""
        if len(samples) == 0:
            return {}
        pad_idx = self.phone_encoder.pad()

        sources = [s['source'] for s in samples]
        mels = [s['target'] for s in samples]

        ids = torch.LongTensor([s['id'] for s in samples])
        utt_ids = [s['utt_id'] for s in samples]
        texts = [s['text'] for s in samples]
        src_tokens = utils.collate_1d(sources, pad_idx)
        targets = utils.collate_2d(mels, pad_idx)
        prev_output_mels = utils.collate_2d(mels, pad_idx, shift_right=True)
        alphas = utils.collate_2ds([s['alpha'] for s in samples], pad_idx)
        phis = utils.collate_2ds([s['phi'] for s in samples], pad_idx)

        # Sort every batch field by descending *target* (mel) length.
        src_lengths = torch.LongTensor([src.numel() for src in sources])
        target_lengths = torch.LongTensor([mel.shape[0] for mel in mels])
        target_lengths, sort_order = target_lengths.sort(descending=True)
        targets = targets.index_select(0, sort_order)
        prev_output_mels = prev_output_mels.index_select(0, sort_order)
        src_tokens = src_tokens.index_select(0, sort_order)
        src_lengths = src_lengths.index_select(0, sort_order)
        alphas = alphas.index_select(0, sort_order)
        phis = phis.index_select(0, sort_order)
        ids = ids.index_select(0, sort_order)
        order = sort_order.tolist()
        utt_ids = [utt_ids[i] for i in order]
        texts = [texts[i] for i in order]

        return {
            'id': ids,                      # (B,)
            'utt_id': utt_ids,              # list of text
            'nsamples': len(samples),       # scalar
            'ntokens': sum(len(src) for src in sources),  # scalar
            'nmels': sum(len(mel) for mel in mels),       # scalar
            'text': texts,                  # list of text
            'src_tokens': src_tokens,       # (B, T)
            'src_lengths': src_lengths,     # (B,)
            'targets': targets,             # (B, U, n_mels)
            'target_lengths': target_lengths,       # (B,)
            'prev_output_mels': prev_output_mels,   # (B, U, n_mels)
            'alphas': alphas,               # (B, T, U)
            'phis': phis,                   # (B, T, U)
        }

    @property
    def num_workers(self):
        """Number of DataLoader workers, read from the NUM_WORKERS env var (default 1)."""
        return int(os.getenv('NUM_WORKERS', 1))


class RSQRTSchedule(object):
    """Linear warm-up followed by reciprocal-square-root learning-rate decay.

    For update ``n`` the learning rate written to every param group is::

        max(lr * min(n / warmup, 1) * max(warmup, n) ** -0.5 * hidden ** -0.5, 1e-7)

    where ``lr``, ``warmup`` and ``hidden`` are read from ``hparams['lr']``,
    ``hparams['warmup_updates']`` and ``hparams['hidden_size']``.  A
    non-positive ``warmup`` disables the warm-up ramp instead of dividing by
    zero.
    """

    def __init__(self, optimizer):
        super().__init__()
        self.optimizer = optimizer
        self.constant_lr = hparams['lr']
        self.warmup_updates = hparams['warmup_updates']
        self.hidden_size = hparams['hidden_size']
        self.lr = hparams['lr']
        # step(0) writes the initial learning rate into every param group.
        self.step(0)

    def step(self, num_updates):
        """Set and return the learning rate for update number ``num_updates``."""
        if self.warmup_updates > 0:
            warmup = min(num_updates / self.warmup_updates, 1.0)
        else:
            # No warm-up requested: skip the ramp (avoids ZeroDivisionError).
            warmup = 1.0
        # The extra `1` only matters when both counters are 0, where
        # `0 ** -0.5` would otherwise raise ZeroDivisionError.
        rsqrt_decay = max(self.warmup_updates, num_updates, 1) ** -0.5
        rsqrt_hidden = self.hidden_size ** -0.5
        # Floor at 1e-7 so the very first update (warmup == 0) is not exactly zero.
        self.lr = max(self.constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        return self.lr

    def get_lr(self):
        """Return the learning rate currently stored in the first param group."""
        return self.optimizer.param_groups[0]['lr']



class SpeechTransducerTask(BaseTask):
    """Training, validation and inference task for the ``SpeechTransducer`` model.

    The task builds the phone encoder, data loaders, model, optimizer and
    learning-rate schedule, evaluates the ``RNNTLossTTS`` criterion during
    training and validation, and at test time decodes mel-spectrograms step
    by step and vocodes them into wav files under ``hparams['work_dir']``.

    Configuration is read from the module-level ``hparams`` dict.
    """

    def __init__(self, *args, **kwargs):
        # These attributes are assigned *before* the base constructor runs;
        # presumably BaseTask.__init__ relies on some of them -- keep the order.
        self.arch = hparams['arch']
        if isinstance(self.arch, str):
            # Architecture given as a whitespace-separated string of ints.
            self.arch = list(map(int, self.arch.strip().split()))
        if self.arch is not None:
            self.num_heads = utils.get_num_heads(self.arch[hparams['enc_layers']:])
        self.vocoder = None
        self.phone_encoder = self.build_phone_encoder(hparams['data_dir'])
        self.padding_idx = self.phone_encoder.pad()
        self.eos_idx = self.phone_encoder.eos()
        self.seg_idx = self.phone_encoder.seg()
        # Never populated in this class; test_end tolerates them being None.
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        super().__init__(*args, **kwargs)
        self.rnnt_loss_tts = RNNTLossTTS()

    @data_loader
    def train_dataloader(self):
        """Shuffled loader over ``hparams['train_set_name']``."""
        train_dataset = SpeechTransducerDataset(hparams['data_dir'], self.phone_encoder,
                                                hparams['train_set_name'], hparams, shuffle=True)
        return self.build_dataloader(train_dataset, True, self.max_tokens, self.max_sentences,
                                     endless=hparams['endless_ds'])

    @data_loader
    def val_dataloader(self):
        """Unshuffled loader over ``hparams['valid_set_name']``."""
        valid_dataset = SpeechTransducerDataset(hparams['data_dir'], self.phone_encoder,
                                                hparams['valid_set_name'], hparams, shuffle=False)
        return self.build_dataloader(valid_dataset, False, self.max_eval_tokens, self.max_eval_sentences)

    @data_loader
    def test_dataloader(self):
        """Unshuffled loader used for testing.

        NOTE(review): this reads ``hparams['valid_set_name']``, i.e. testing
        runs on the validation split -- confirm this is intentional.
        """
        test_dataset = SpeechTransducerDataset(hparams['data_dir'], self.phone_encoder,
                                               hparams['valid_set_name'], hparams, shuffle=False)
        return self.build_dataloader(test_dataset, False, self.max_eval_tokens, self.max_eval_sentences)

    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False):
        """Create a ``DataLoader`` whose batches are pre-computed by size.

        Args:
            dataset: dataset exposing ``ordered_indices()``, ``num_tokens``,
                ``collater`` and ``num_workers``.
            shuffle: shuffle the order of the batches (not their content).
            max_tokens: per-device token budget per batch, or None.
            max_sentences: per-device sentence budget per batch, or None.
            required_batch_size_multiple: ``-1`` means "number of devices".
            endless: repeat the batch list 1000 times (reshuffled on every
                repetition when ``shuffle`` is set).

        Returns:
            torch.utils.data.DataLoader
        """
        # torch.cuda.device_count() is 0 on CPU-only hosts, which would zero
        # the budgets and the batch-size multiple; treat that as one device.
        num_devices = max(1, torch.cuda.device_count())
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = num_devices
        if max_tokens is not None:
            max_tokens *= num_devices
        if max_sentences is not None:
            max_sentences *= num_devices

        def shuffle_batches(batches):
            np.random.shuffle(batches)
            return batches

        indices = dataset.ordered_indices()
        batch_sampler = utils.batch_by_size(
            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        if shuffle:
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(1000) for b in batches]

        if self.trainer.use_ddp:
            # Each rank takes a strided slice of every batch; batches that do
            # not split evenly across ranks are dropped.
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=dataset.num_workers,
                                           pin_memory=False)

    def build_phone_encoder(self, data_dir):
        """Build the phoneme encoder from ``<data_dir>/phone_set.json``."""
        phone_list_file = os.path.join(data_dir, 'phone_set.json')
        # Context manager so the file handle is closed deterministically.
        with open(phone_list_file) as f:
            phone_list = json.load(f)
        return TokenTextEncoder(None, vocab_list=phone_list)

    def build_model(self):
        """Instantiate the ``SpeechTransducer`` for the configured architecture."""
        return speech_transducer.SpeechTransducer(self.arch, self.phone_encoder)

    def build_scheduler(self, optimizer):
        """Wrap ``optimizer`` in the warm-up / rsqrt-decay schedule."""
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        """Create (and remember as ``self.optimizer``) an AdamW optimizer from ``hparams``."""
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'],
            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
            weight_decay=hparams['weight_decay'])
        return optimizer

    def _forward_losses(self, sample):
        """Run the teacher-forced forward pass on ``sample`` and return the loss dict."""
        targets = sample['targets']
        # Published through hparams before the forward pass, as the training
        # and validation steps have always done.
        hparams['global_steps'] = self.global_step
        outputs = self.model(sample['src_tokens'], sample['prev_output_mels'], targets)
        return self.loss(outputs, targets, sample['src_lengths'], sample['target_lengths'],
                         sample['alphas'], sample['phis'])

    @staticmethod
    def _total_loss(losses):
        """Sum all loss terms except ``'ref_mel'``, which is reported but not optimised."""
        return sum([v for k, v in losses.items() if k != 'ref_mel'])

    def _training_step(self, sample, batch_idx, _):
        """Return ``(total_loss, loss_dict)`` for one training batch."""
        loss_output = self._forward_losses(sample)
        total_loss = self._total_loss(loss_output)
        # Added after the total is computed so it is not summed into the loss.
        loss_output['batch_size'] = sample['src_tokens'].size(0)
        return total_loss, loss_output

    def loss(self, outputs, targets, output_lengths, target_lengths, alphas, phis):
        """Evaluate ``RNNTLossTTS`` and apply the ``lambda_*`` weights from ``hparams``.

        Returns:
            dict with keys ``'mel'``, ``'ref_mel'`` (unweighted),
            ``'diag_cumsum'`` and ``'diag'``.
        """
        mel_loss, diag_cumsum_constrain, diag_constrain, ref_mel_loss = self.rnnt_loss_tts(
            outputs,
            targets,
            output_lengths,
            target_lengths,
            alphas,
            phis
        )
        return {
            'mel': mel_loss * hparams['lambda_mel'],
            'ref_mel': ref_mel_loss,
            'diag_cumsum': diag_cumsum_constrain * hparams['lambda_diag_cumsum'],
            'diag': diag_constrain * hparams['lambda_diag'],
        }

    def validation_step(self, sample, batch_idx):
        """Compute the losses for one validation batch and return them as scalars."""
        losses = self._forward_losses(sample)
        outputs = {
            'losses': losses,
            'total_loss': self._total_loss(losses),
            'nmels': sample['nmels'],
            'nsamples': sample['nsamples'],
        }
        return utils.tensors_to_scalars(outputs)

    def _validation_end(self, outputs):
        """Average every loss over all validation batches, weighted by ``nmels``."""
        all_losses_meter = {
            'total_loss': utils.AvgrageMeter(),
        }
        for output in outputs:
            n = output['nmels']
            for k, v in output['losses'].items():
                if k not in all_losses_meter:
                    all_losses_meter[k] = utils.AvgrageMeter()
                all_losses_meter[k].update(v, n)
            all_losses_meter['total_loss'].update(output['total_loss'], n)
        return {k: round(v.avg, 4) for k, v in all_losses_meter.items()}

    def test_step(self, sample, batch_idx):
        """Synthesize and save audio for one test batch."""
        self.test_step_tts(sample, batch_idx)

    def test_end(self, outputs):
        """Wait for any asynchronous result-saving jobs to finish.

        ``saving_result_pool`` is initialised to None and nothing in this
        class creates it (results are written synchronously), so a missing
        pool is tolerated instead of raising ``AttributeError``.
        """
        if self.saving_result_pool is not None:
            self.saving_result_pool.close()
            for future in tqdm(self.saving_results_futures or []):
                future.get()
            self.saving_result_pool.join()
        return {}

    @staticmethod
    def _latest_checkpoint(base_dir, glob_pattern, step_regex):
        """Return the file in ``base_dir`` matching ``glob_pattern`` with the highest step.

        ``step_regex`` is applied to the file *name* only and must capture the
        step number in group 1, so special characters in ``base_dir`` cannot
        affect the match.

        Raises:
            FileNotFoundError: no file with a parsable step number exists.
        """
        step_re = re.compile(step_regex)
        candidates = []
        for path in glob.glob(f'{base_dir}/{glob_pattern}'):
            match = step_re.search(os.path.basename(path))
            if match is not None:
                candidates.append((int(match.group(1)), path))
        if not candidates:
            raise FileNotFoundError(
                f"no vocoder checkpoint matching '{glob_pattern}' found in '{base_dir}'")
        candidates.sort(key=lambda item: item[0])
        return candidates[-1][1]

    def prepare_vocoder(self):
        """Load the vocoder and denoiser once; only ``hparams['vocoder'] == 'pwg'`` is handled.

        With an empty ``hparams['vocoder_ckpt']`` the newest
        ``checkpoint-<N>steps.pkl`` under ``wavegan_pretrained`` is used,
        otherwise the newest ``model_ckpt_steps_<N>.ckpt`` under that directory.
        """
        if self.vocoder is not None or hparams['vocoder'] != 'pwg':
            return
        if hparams['vocoder_ckpt'] == '':
            base_dir = 'wavegan_pretrained'
            ckpt = self._latest_checkpoint(
                base_dir, 'checkpoint-*steps.pkl', r'checkpoint-(\d+)steps\.pkl')
        else:
            base_dir = hparams['vocoder_ckpt']
            ckpt = self._latest_checkpoint(
                base_dir, 'model_ckpt_steps_*.ckpt', r'model_ckpt_steps_(\d+)\.ckpt')
        config_path = f'{base_dir}/config.yaml'
        print('| load wavegan: ', ckpt)
        self.vocoder = load_pwg_model(
            config_path=config_path,
            checkpoint_path=ckpt,
            stats_path=f'{base_dir}/stats.h5',
        )
        self.denoiser = GeneralDenoiser()

    def inv_spec(self, spec, pitch=None, noise_spec=None):
        """Vocode a mel-spectrogram into a waveform.

        :param spec: [T, 80]
        :param pitch: unused.
        :param noise_spec: optional spectrogram whose vocoded output is handed
            to the denoiser when ``hparams['gen_wav_denoise']`` is set.
        :return: numpy waveform; None when ``hparams['vocoder']`` is not 'pwg'.
        """
        if hparams['vocoder'] == 'pwg':
            wav_out = generate_wavegan(spec, *self.vocoder, profile=hparams['profile_infer'])
            if hparams['gen_wav_denoise']:
                noise_out = generate_wavegan(noise_spec, *self.vocoder)[None, :] \
                    if noise_spec is not None else None
                wav_out = self.denoiser(wav_out[None, :], noise_out)[0, 0]
            wav_out = wav_out.cpu().numpy()
            return wav_out

    def _decode_and_save(self, token, tgt_mel, utt_id, text, shift_threshold=0.5,
                         max_shift_length=None):
        """Decode one utterance frame by frame, vocode it and write the wav files.

        At every step the joint network emits a mel frame plus one extra
        value; when that value exceeds ``logit(shift_threshold)`` decoding
        moves on to the next phoneme, otherwise the frame is emitted and fed
        back as the next speech-encoder input.

        Args:
            token: (T,) phoneme ids.
            tgt_mel: (U, n_mels) normalised ground-truth mel.
            utt_id: utterance id used in the output file name.
            text: transcript used in the output file name.
            shift_threshold: probability threshold for shifting phoneme.
            max_shift_length: if set, force a shift after this many frames on
                one phoneme.

        NOTE(review): ``token`` and ``tgt_mel`` are rows of the padded batch;
        decoding relies on meeting '<EOS>' before any padding, and the
        ground-truth wav is vocoded at padded length -- consider trimming
        with ``sample['src_lengths']`` / ``sample['target_lengths']``.
        """
        shift_tag = np.log(shift_threshold / (1 - shift_threshold))
        model = self.model
        n_mels = hparams['audio_num_mel_bins']

        input_length = len(token)
        # Upper bound on the number of generated frames.
        max_output_length = input_length * 5 + 150
        token = token.unsqueeze(0)  # (T,) -> (1, T)
        # Transposed so that the encoder output is indexed as (1, T, H) below.
        text_encoder_outputs = model.forward_text_encoder(token)['encoder_out'].transpose(0, 1)
        # Previous-frame inputs, (1, U_max + 1, n_mels); slot 0 stays all-zero
        # as the start frame and slot k holds the k-th generated frame.
        speech_encoder_inputs = torch.zeros(
            1, max_output_length + 1, n_mels, dtype=torch.float, device=token.device)

        output_length = 0
        phone_step = 0
        shift_length = 0
        while phone_step < input_length and output_length < max_output_length:
            # Phonemes up to and including the current one, (1, Tx, H).
            text_encoder_output = text_encoder_outputs[:, :phone_step + 1]
            # Frames generated so far, preceded by the start frame, (1, Ux, n_mels).
            speech_encoder_input = speech_encoder_inputs[:, :output_length + 1]
            speech_encoder_output = model.forward_speech_encoder(
                speech_encoder_input, incremental_state=None)
            # Only the newest (phoneme, frame) cell of the joint output is needed.
            joint_output = model.tts_joint_network(text_encoder_output, speech_encoder_output)[0, -1, -1]
            mel, shift_logit = joint_output[:-1], joint_output[-1]
            # Early stop once the current phoneme is the end-of-sequence token.
            if self.phone_encoder._id_to_token[token[0][phone_step].item()] == '<EOS>':
                break
            if shift_logit > shift_tag:
                # Shift to the next phoneme without emitting a frame.
                phone_step += 1
                shift_length = 0
                continue
            output_length += 1
            speech_encoder_inputs[0, output_length] = mel
            shift_length += 1
            if max_shift_length is not None and shift_length > max_shift_length:
                phone_step += 1
                shift_length = 0
                print(f">>> {phone_step-1} out of threshold {max_shift_length}, automatic skip to next phone !!!")

        # The generated frames are exactly slots 1..output_length of the buffer,
        # so no per-frame concatenation is needed.
        output_mels = speech_encoder_inputs[:, 1:output_length + 1]

        # gen wave
        gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}')
        os.makedirs(f'{gen_dir}/wavs', exist_ok=True)
        spec = denormalize_mel(output_mels[0].cpu().numpy(), hparams)
        wav_out = generate_wavegan(spec, *self.vocoder, profile=hparams['profile_infer'])
        tgt_mel = denormalize_mel(tgt_mel.cpu().numpy(), hparams)
        target_wav_out = generate_wavegan(tgt_mel, *self.vocoder, profile=hparams['profile_infer'])
        # 'P' = predicted, 'G' = vocoded ground truth.
        self.save_result(wav_out, 'P', utt_id, text, gen_dir)
        self.save_result(target_wav_out, 'G', utt_id, text, gen_dir)

    def test_step_tts(self, sample, batch_idx):
        """Synthesize every utterance of ``sample`` and save predicted and ground-truth wavs."""
        self.prepare_vocoder()
        logging.info('inferring batch {} with {} samples'.format(batch_idx, sample['nsamples']))
        with utils.Timer('trans_tts', print_time=hparams['profile_infer']):
            inputs = sample['src_tokens']
            targets = sample['targets']
            utt_ids = sample.get('utt_id')
            texts = sample.get('text')
            for tokens, tgt_mel, utt_id, text in zip(inputs, targets, utt_ids, texts):
                self._decode_and_save(tokens, tgt_mel, utt_id, text)

    ##########
    # utils
    ##########
    @staticmethod
    def save_result(wav_out, prefix, utt_id, text, gen_dir):
        """Write ``wav_out`` to ``<gen_dir>/wavs/[<prefix>][<utt_id>]<text>.wav``.

        ':' in ``text`` is replaced by '%3A' and the text is cut to 80
        characters.  ``wav_out`` must provide ``.cpu().numpy()``.
        """
        base_fn = f'[{prefix}][{utt_id}]'
        base_fn += text.replace(":", "%3A")[:80]
        wavfile.write(f'{gen_dir}/wavs/{base_fn}.wav', hparams['audio_sample_rate'], wav_out.cpu().numpy())



# Script entry point: only runs when this file is executed directly.
# `start` is not defined in this file; presumably it is provided by BaseTask.
if __name__ == '__main__':
    SpeechTransducerTask.start()

