import utils.commons.single_thread_env  # NOQA
import re
import uuid
import glob
import random
import traceback
from torch.utils.data import TensorDataset
from tqdm import tqdm
import os
import ray
import torch
import torch.nn.functional as F
import torch.optim
import torch.utils.data
from modules.tts.commons.align_ops import clip_seq_to_multiple
from monotonic_align import get_best_alignments
from tasks.tts.fs_adv import build_disc
from tasks.tts.speech_base import SpeechBaseTask
from .dataset import UnsuperTTSBidirDataset, UnsuperTTSBidirBTDataset, UnsuperTTSVCDataset
from .models.lm import TransformerDecoder
from .models.speech_generator.model import SpeechGenerator
from .models.tacotron2.model import RNNTranslator
from utils.audio.align import mel2token_to_dur
from utils.audio.io import save_wav
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.dataset_utils import BaseConcatDataset, collate_1d_or_2d
from utils.commons.hparams import hparams
from utils.commons.indexed_datasets import IndexedDataset2Builder, IndexedDataset2
from utils.commons.ray_utils import ray_init, ray_shutdown
from utils.commons.tensor_utils import move_to_cuda
from utils.nn.model_utils import print_arch
from utils.nn.schedulers import WarmupSchedule
from utils.os_utils import copy_file, remove_file
from utils.text.text_encoder import build_token_encoder, PUNCS


def get_txt_dict(hparams):
    """Build the text token encoder that matches ``hparams['text_type']``.

    IPA text uses ``phone_set.json``; every other text type uses
    ``chr_set.json``. Both are looked up in ``hparams['binary_data_dir']``.
    """
    data_root = hparams['binary_data_dir']
    use_phones = hparams['text_type'] == 'ipa'
    if use_phones:
        return build_token_encoder(f'{data_root}/phone_set.json')
    return build_token_encoder(f'{data_root}/chr_set.json')


def text_token_to_str(tokens, txt_encoder):
    """Decode ``tokens`` into one string, ignoring every id that is not > 2."""
    kept_ids = []
    for token_id in tokens.numpy():
        if token_id > 2:
            kept_ids.append(token_id)
    decoded_pieces = txt_encoder.decode_list(kept_ids)
    return "".join(decoded_pieces)


def text_norm(s):
    """Lowercase ``s``, turn punctuation into spaces and collapse whitespace runs."""
    lowered = s.lower()
    # Every character listed in PUNCS becomes a single space.
    without_puncs = re.sub(f"[{PUNCS}]", r" ", lowered)
    single_spaced = re.sub(r"\s+", " ", without_puncs)
    return single_spaced.strip()


def inference_job(worker_i, total_workers, get_models, task_hparams,
                  s2t_config, t2s_config, p_dict_size, h_dict_size, s2t_data_dir, t2s_data_dir,
                  infer_s2t, infer_t2s, use_cuda=False):
    """Run one worker's share of back-translation inference and write it to disk.

    The worker loads the latest weights from
    ``checkpoints/<exp_name>/model_latest.pt``, takes a contiguous slice of (at
    most) the first 100000 dataset items, and writes its results through
    ``IndexedDataset2Builder`` to ``<s2t_data_dir>/<worker_i>`` and
    ``<t2s_data_dir>/<worker_i>``.

    * ``infer_t2s`` is not ``None``: items of the unpaired text dataset are
      synthesized to mels and stored in the s2t dataset.
    * otherwise: items of the unpaired audio dataset are transcribed with
      ``infer_s2t`` and stored, with an alignment score, in the t2s dataset.

    An exception raised while processing an item is printed and that item is
    skipped.

    :param worker_i: index of this worker, used to pick its slice and file name
    :param total_workers: number of workers the items are split across
    :param get_models: callable returning ``(model_s2t, model_t2s)``
    :param task_hparams: mapping merged into the global ``hparams`` of this process
    :param s2t_config: config forwarded to ``get_models``
    :param t2s_config: config forwarded to ``get_models``
    :param p_dict_size: dictionary size forwarded to ``get_models``
    :param h_dict_size: dictionary size forwarded to ``get_models``
    :param s2t_data_dir: output directory of the s2t dataset
    :param t2s_data_dir: output directory of the t2s dataset
    :param infer_s2t: speech-to-text inference callable (used when ``infer_t2s`` is None)
    :param infer_t2s: text-to-speech inference callable, or ``None``
    :param use_cuda: move models and samples to CUDA when true
    """
    import logging
    # Raise the level of every known logger to WARNING inside this worker.
    for logger_ in [logging.getLogger(x) for x in logging.root.manager.loggerDict] + [logging.getLogger()]:
        logger_.setLevel(logging.WARNING)
    from utils.commons.hparams import hparams

    hparams.update(task_hparams)
    exp_name = hparams['exp_name']
    model_s2t, model_t2s = get_models(hparams, p_dict_size, h_dict_size, t2s_config, s2t_config)
    model_t2s.eval()
    model_s2t.eval()
    data_dir = hparams['binary_data_dir']
    if infer_t2s is not None:
        ds = BaseConcatDataset([
            UnsuperTTSBidirDataset(shuffle=False, data_dir=f"{data_dir}/{hparams['unsuper_ds_txt']}", prefix='train'),
        ])
    else:
        ds = BaseConcatDataset([
            UnsuperTTSBidirDataset(shuffle=False, data_dir=f"{data_dir}/{hparams['unsuper_ds_audio']}",
                                   prefix='train'),
        ])
    txt_token_encoder = get_txt_dict(hparams)
    # The language model is optional: it is only built when a checkpoint is
    # configured and its weight is positive.
    lm = None
    if hparams['lm_load_ckpt'] != '' and hparams['lm_weight'] > 0:
        lm = TransformerDecoder(len(txt_token_encoder), **hparams['lm_model_config']['model_conf'])
        load_ckpt(lm, hparams['lm_load_ckpt'], silent=True)
        lm.eval()
        if use_cuda:
            lm.cuda()

    # Only the first 100000 items are considered; each worker gets a contiguous slice.
    ds_len = min(100000, len(ds))
    n_items_worker = ds_len // total_workers + 1
    i_s = worker_i * n_items_worker
    i_end = min(i_s + n_items_worker, ds_len)
    x_list = list(range(i_s, i_end))
    random.shuffle(x_list)
    model_ckpt = torch.load(f'checkpoints/{exp_name}/model_latest.pt', map_location='cpu')
    model_t2s.load_state_dict(model_ckpt['t2s_state_dict'])
    model_t2s.eval()
    model_s2t.load_state_dict(model_ckpt['s2t_state_dict'])
    model_s2t.eval()
    if use_cuda:
        model_t2s.cuda()
        model_s2t.cuda()
    # NOTE(review): this checks the path without the ".data" suffix although the
    # file copied below is "<worker_i>.data" -- confirm which one is intended.
    replace_if_better = hparams.get('replace_if_better', False) and os.path.exists(f'{t2s_data_dir}/{worker_i}')
    if replace_if_better:
        # Work on a copy of the previous output because the builder below
        # targets the same "<worker_i>" path.
        copy_file(f'{t2s_data_dir}/{worker_i}.data', f'{t2s_data_dir}/{worker_i}_tmp.data')
        last_t2s_ds = IndexedDataset2(f'{t2s_data_dir}/{worker_i}_tmp')
        last_scores = last_t2s_ds.meta['scores']
    s2t_ds_builder = IndexedDataset2Builder(f'{s2t_data_dir}/{worker_i}')
    t2s_ds_builder = IndexedDataset2Builder(f'{t2s_data_dir}/{worker_i}')
    s2t_ds_builder.meta['lens'] = []
    t2s_ds_builder.meta['lens'] = []
    t2s_ds_builder.meta['scores'] = []
    for i in x_list:
        item = ds[i]
        sample = ds.collater([item])
        if use_cuda:
            sample = move_to_cuda(sample)
        try:
            with torch.inference_mode():
                if infer_t2s is not None:
                    del item['mel']
                    del item['h_token']
                    # The speaker embedding is taken from the following item.
                    # NOTE(review): ds[i + 1] can be out of range for the last
                    # item; that item is then skipped by the handler below.
                    spk_embed = ds.collater([ds[i + 1]])['spk_embed'] if hparams['use_spk_embed'] else None
                    mel_pred, out_t2s = infer_t2s(model_t2s, sample, hparams, None, spk_embed)
                    mel_pred = mel_pred[0].cpu()
                    item['mel'] = mel_pred
                    if hparams['use_spk_embed']:
                        item['spk_embed'] = spk_embed
                    s2t_ds_builder.add_item(item)
                    s2t_ds_builder.meta['lens'].append(len(item["t_token"]))
                else:  # unpaired speech
                    gt_t_token = item['t_token']
                    del item['t_token']
                    t_pred, mel2ph, ali = infer_s2t(
                        model_s2t, sample, hparams, lm=lm, lm_weight=hparams['lm_weight'], return_ali=True)
                    ali = ali.cpu()
                    mel2ph = mel2ph.cpu()
                    # Mean alignment value along the chosen path, used as a score.
                    ali_cr = ali.gather(1, mel2ph[:, None, :]).mean().item()
                    t_pred = t_pred[0].cpu()
                    t_pred = F.pad(t_pred, [1, 0], value=2)
                    if ali_cr < hparams.get('filter_ali_cr', 0.1):
                        continue
                    # NOTE(review): last_scores is indexed with the dataset index
                    # i, but scores are appended in processing order -- verify.
                    if replace_if_better and ali_cr < last_scores[i]:
                        t2s_ds_builder.add_item(last_t2s_ds[i])
                        # The builder stores lengths under 'lens' (see above),
                        # so the previous dataset is read with the same key.
                        t2s_ds_builder.meta['lens'].append(last_t2s_ds.meta['lens'][i])
                        t2s_ds_builder.meta['scores'].append(last_scores[i])
                        continue
                    if t_pred[-1] != 1:
                        t_pred = F.pad(t_pred, [0, 1], value=1)
                    item['t_token'] = t_pred
                    if mel2ph is not None:
                        item['mel2ph'] = mel2ph[0]
                    t2s_ds_builder.add_item(item)
                    t2s_ds_builder.meta['lens'].append(len(item["mel"]))
                    t2s_ds_builder.meta['scores'].append(ali_cr)
                    item['mel'] = None
                    item['index_raw'] = i
        except Exception:
            # Best effort: report the failure and continue with the next item.
            # KeyboardInterrupt / SystemExit are deliberately not caught.
            traceback.print_exc()
    if replace_if_better:
        # The temporary copy was created under t2s_data_dir, so remove it there.
        remove_file(f'{t2s_data_dir}/{worker_i}_tmp.data')
    t2s_ds_builder.finalize()
    s2t_ds_builder.finalize()
    print(f"| worker #{worker_i} finished.")


def merge_samples(samples_list):
    """Concatenate utterances drawn from several batches into longer utterances.

    Each batch's item indices are shuffled independently. Merged item ``i`` is
    the concatenation of the ``i``-th shuffled item of every batch: mel rows
    whose absolute sum is zero are dropped, token ids ``<= 2`` are stripped, and
    the merged token sequence is wrapped with a leading ``2`` and a trailing
    ``1``.

    A constituent with no remaining mel rows or tokens is skipped. If every
    constituent of a merged item is skipped, that merged item is dropped.

    :param samples_list: list of batch dicts with the keys ``nsamples``,
        ``mels``, ``t_tokens``, ``s_lengths``, ``t_lengths`` and ``lang_ids``
    :return: the merged batch dict, or ``None`` when nothing could be merged
    """
    ids_shuffle = [list(range(s['nsamples'])) for s in samples_list]
    for ids in ids_shuffle:
        random.shuffle(ids)
    min_len = min(len(ids) for ids in ids_shuffle) // len(samples_list)
    new_samples = {
        'mels': [],
        't_tokens': [],
        's_lengths': [],
        't_lengths': [],
    }
    for i in range(min_len):
        mels = []
        s_lengths = 0
        t_tokens = []
        t_lengths = 0
        for ids, samples in zip(ids_shuffle, samples_list):
            id_ij = ids[i]
            mels_ = samples['mels'][id_ij]
            mels_ = mels_[(mels_.abs().sum(-1) > 0)]
            t_tokens_ = samples['t_tokens'][id_ij]
            t_tokens_ = t_tokens_[t_tokens_ > 2]
            if mels_.shape[0] == 0 or len(t_tokens_) == 0:
                continue
            mels.append(mels_)
            t_tokens.append(t_tokens_)
            s_lengths = s_lengths + samples['s_lengths'][id_ij]
            assert len(t_tokens_) == samples['t_lengths'][id_ij] - 2
            t_lengths = t_lengths + samples['t_lengths'][id_ij] - 2
        if not mels:
            # Every constituent was empty; torch.cat on an empty list would raise.
            continue
        new_samples['mels'].append(torch.cat(mels, 0))
        new_samples['s_lengths'].append(s_lengths)
        t_tokens = torch.cat(t_tokens, 0)
        t_tokens = F.pad(t_tokens, [1, 0], value=2)
        t_tokens = F.pad(t_tokens, [0, 1], value=1)
        t_lengths = t_lengths + 2
        new_samples['t_tokens'].append(t_tokens)
        new_samples['t_lengths'].append(t_lengths)
    n_merged = len(new_samples['mels'])
    if n_merged == 0:
        return None
    new_samples['mels'] = collate_1d_or_2d(new_samples['mels'], 0)
    new_samples['t_tokens'] = collate_1d_or_2d(new_samples['t_tokens'], 0)
    new_samples['t_lengths'] = torch.LongTensor(new_samples['t_lengths'])
    new_samples['s_lengths'] = torch.LongTensor(new_samples['s_lengths'])
    # Report the number of items actually produced (equal to min_len unless
    # some merged items were dropped above).
    new_samples['nsamples'] = n_merged
    new_samples['lang_ids'] = samples_list[0]['lang_ids'][:n_merged]
    if samples_list[0]['mels'].device.type == 'cuda':
        # Keep the returned object instead of relying on in-place mutation.
        new_samples = move_to_cuda(new_samples)
    return new_samples


def wenet_asr_recog(wav):
    """Transcribe ``wav`` with a WeNet websocket server and return the best sentence.

    The waveform is scaled by 32768 and sent as 16-bit PCM between a "start"
    and an "end" signal. The first server reply is discarded; the second one
    carries the n-best list.

    :param wav: array-like waveform supporting ``* 32768`` and ``.astype``
        (presumably a float NumPy array in [-1, 1] -- TODO confirm)
    :return: the ``sentence`` field of the first n-best entry

    NOTE(review): the host in the URL below looks like a placeholder that has
    to be replaced before this function can work.
    """
    from websocket import create_connection
    import json

    ws = create_connection("ws://WENET_IP:10086")
    try:
        ws.send(json.dumps({"signal": "start", "nbest": 1, "continuous_decoding": False}))
        # Clip before casting: a full-scale sample (1.0 * 32768) does not fit
        # into int16 and would wrap around to a negative value.
        pcm16 = (wav * 32768).clip(-32768, 32767).astype("int16")
        ws.send_binary(pcm16.tobytes())
        ws.send(json.dumps({"signal": "end"}))
        _ = ws.recv()
        message = ws.recv()
        message = json.loads(json.loads(message)['nbest'])
        return message[0]['sentence']
    finally:
        # Close the connection even when sending, receiving or parsing fails.
        ws.close()


def azure_asr_recog(wav):
    """Perform one-shot speech recognition on ``wav`` with the Azure Speech SDK.

    The waveform is written to a temporary 16 kHz wav file, recognized once
    with language ``id-ID``, and the temporary file is removed afterwards, also
    when recognition raises.

    :param wav: waveform passed to ``save_wav``
    :return: the recognized text, or ``''`` when no speech was recognized

    The process exits with status 1 if the Speech SDK cannot be imported.
    NOTE(review): ``subscription_key`` below looks like a placeholder.
    """
    try:
        import azure.cognitiveservices.speech as speechsdk
    except ImportError:
        print("""
        Importing the Speech SDK for Python failed.
        Refer to
        https://docs.microsoft.com/azure/cognitive-services/speech-service/quickstart-python for
        installation instructions.
        """)
        import sys
        sys.exit(1)
    subscription_key = "YOUR_KEY"
    location = "eastus"
    speech_config = speechsdk.SpeechConfig(subscription=subscription_key, region=location)
    tmp_fn = f'/tmp/{uuid.uuid4()}.wav'
    save_wav(wav, tmp_fn, 16000)
    try:
        audio_config = speechsdk.audio.AudioConfig(filename=tmp_fn)
        speech_recognizer = speechsdk.SpeechRecognizer(
            speech_config=speech_config, language="id-ID", audio_config=audio_config)
        result = speech_recognizer.recognize_once()
    finally:
        # Do not leave the temporary wav behind when recognition fails.
        remove_file(tmp_fn)
    if result.reason == speechsdk.ResultReason.RecognizedSpeech:
        return result.text
    return ''


class UnsuperTTSTask(SpeechBaseTask):
    def __init__(self, *args, **kwargs):
        """Read language settings from ``hparams`` and build the token encoders."""
        super().__init__(*args, **kwargs)
        self.rich_langs = hparams['rich_langs']
        self.unsuper_training = hparams['training_mode'] == 'unsuper'
        # One iterator slot per direction; filled lazily on first use.
        self.unsup_train_dataloaders = dict.fromkeys(('s2t', 't2s'))
        self.sup_train_dataloaders = dict.fromkeys(('s2t', 't2s'))
        # Unsupervised runs are evaluated on the dedicated test set only.
        self.test_langs = [hparams['unsuper_test_ds']] if self.unsuper_training else hparams['rich_langs']
        binary_dir = hparams['binary_data_dir']
        self.txt_token_encoder = get_txt_dict(hparams)
        self.hb_token_encoder = build_token_encoder(f'{binary_dir}/hubert_set.json')
        self.txt_dict_size = len(self.txt_token_encoder)
        self.hb_dict_size = len(self.hb_token_encoder)
        lang_to_id = {}
        for lang_name, lang_conf in hparams['langs'].items():
            lang_to_id[lang_name] = lang_conf['lang_id']
        self.lang2id = lang_to_id
        self.dataset_cls = UnsuperTTSBidirDataset
        self.cur_stage = 'both'

    def train_dataloader(self):
        return torch.utils.data.DataLoader(TensorDataset(torch.cat([torch.FloatTensor(1)] * 10000)))

    def get_next_train_sample(self, direction):
        """Return the next supervised batch for ``direction`` ('s2t' or 't2s').

        When the stored iterator cannot deliver a batch (it was never created,
        it is exhausted, or it raised), a fresh dataloader over all rich
        languages is built, stored, and its first batch is returned.
        """
        batch_iter = self.sup_train_dataloaders[direction]
        try:
            return next(batch_iter)
        except Exception:
            binary_dir = hparams['binary_data_dir']
            # The two directions differ only in the key used to size items.
            size_key = 'mels' if direction == 't2s' else 'chr_tokens'
            per_lang_datasets = []
            for lang_name in self.rich_langs:
                per_lang_datasets.append(self.dataset_cls(
                    size_key=size_key,
                    prefix='train', shuffle=True, data_dir=f"{binary_dir}/{lang_name}"))
            train_dataset = BaseConcatDataset(per_lang_datasets)
            token_budget = hparams[f'{direction}_max_tokens']
            loader = self.build_dataloader(train_dataset, True, token_budget, self.max_sentences)
            batch_iter = iter(loader)
            self.sup_train_dataloaders[direction] = batch_iter
            return next(batch_iter)

    def val_dataloader(self):
        """Build the validation loader over the 'test' split of every test language."""
        binary_dir = hparams['binary_data_dir']
        per_lang_datasets = []
        for lang_name in self.test_langs:
            per_lang_datasets.append(
                self.dataset_cls(prefix='test', shuffle=False, data_dir=f"{binary_dir}/{lang_name}"))
        val_dataset = BaseConcatDataset(per_lang_datasets)
        return self.build_dataloader(val_dataset, False, self.max_valid_tokens, self.max_valid_sentences)

    def test_dataloader(self):
        """Testing reuses the validation dataloader."""
        loader = self.val_dataloader()
        return loader

    @staticmethod
    def get_models(hparams, p_dict_size, h_dict_size, t2s_config, s2t_config):
        """Instantiate both translation models.

        :return: ``(model_s2t, model_t2s)`` -- note the order.
        """
        mel_dims = 80
        # The t2s model is built first, as before, so that any random
        # initialization happens in the same order.
        text_to_speech = SpeechGenerator(
            p_dict_size, t2s_config, use_dur=True, out_dims=mel_dims, vae_dims=80)
        s2t_in_dims = 0
        speech_to_text = RNNTranslator(s2t_in_dims, p_dict_size, s2t_config)
        return speech_to_text, text_to_speech

    def build_disc_model(self):
        """Create the mel discriminator and print its architecture."""
        n_mel_bins = hparams['audio_num_mel_bins']
        self.mel_disc = build_disc(hparams, n_mel_bins)
        print_arch(self.mel_disc, model_name='Mel Disc')

    def build_optimizer(self, model):
        """Create the AdamW optimizers, ordered ``[s2t, t2s, mel discriminator]``.

        The ``model`` argument is not used; the optimizers are built from the
        task's own sub-models.
        """
        adam_betas = (hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2'])
        # Both translation models share learning rate, betas and weight decay.
        optimizers = []
        for translator in (self.model_s2t, self.model_t2s):
            optimizers.append(torch.optim.AdamW(
                translator.parameters(),
                lr=hparams['lr'],
                betas=adam_betas,
                weight_decay=hparams['weight_decay']))
        # The discriminator has its own learning rate and extra optimizer params.
        optimizers.append(torch.optim.AdamW(
            self.mel_disc.parameters(),
            lr=hparams['disc_lr'],
            betas=adam_betas,
            **hparams["discriminator_optimizer_params"]))
        return optimizers

    @staticmethod
    def infer_t2s(model_t2s, sample, hparams, token2mel_model=None, spk_embed=None):
        out = model_t2s(sample['t_tokens'], sample['lang_ids'], spk_embed=spk_embed, infer=True)
        mel_pred = out['mel_out']
        return mel_pred, out

    @staticmethod
    def infer_s2t(model_s2t, sample, hparams, txt_encoder=None, lm=None, lm_weight=0.5, return_ali=False):
        """Decode text tokens from mels and derive a frame-to-token alignment.

        :return: ``(t_pred, mel2ph)``, or ``(t_pred, mel2ph, ali)`` when
            ``return_ali`` is true
        """
        mels = sample['mels']
        lang_ids = sample['lang_ids']
        t_pred, _unused, ali = model_s2t.inference(lang_ids, mels, lm, lm_weight)[:3]
        src_lengths = sample['s_lengths']
        # Count the predicted tokens with id > 0, plus one.
        tgt_lengths = (t_pred > 0).long().sum(-1) + 1
        if ali is None:
            # Decoding returned no alignment: run a forward pass on the
            # prediction (prefixed with token 2) to obtain one.
            decoder_input = F.pad(t_pred, [1, 0], value=2)
            _out, _gate, ali = model_s2t(lang_ids, mels, src_lengths, decoder_input, tgt_lengths)[:3]
        _first, mel2ph, _last = get_best_alignments(ali.clone(), tgt_lengths, src_lengths)
        if not return_ali:
            return t_pred, mel2ph
        return t_pred, mel2ph, ali

    def build_model(self):
        """Build both translation models and the discriminator, then initialize them.

        Optionally loads checkpoints and, in unsupervised mode, initializes the
        language embedding of the unsupervised language either from the mean of
        the rich languages (``'mean'``) or from one named language.
        """
        self.t2s_config = hparams['t2s_model_config']
        self.s2t_config = hparams['s2t_model_config']
        self.model_s2t, self.model_t2s = self.get_models(
            hparams, self.txt_dict_size, self.hb_dict_size, self.t2s_config, self.s2t_config)
        self.build_disc_model()

        s2t_ckpt = hparams['s2t_load_ckpt']
        if s2t_ckpt != '':
            load_ckpt(self.model_s2t, s2t_ckpt, 'model_s2t')
        t2s_ckpt = hparams['t2s_load_ckpt']
        if t2s_ckpt != '':
            # The t2s checkpoint also carries the discriminator weights.
            for module, ckpt_key in ((self.model_t2s, 'model_t2s'), (self.mel_disc, 'mel_disc')):
                load_ckpt(module, t2s_ckpt, ckpt_key)

        if 'lang_embed_init' in hparams and self.unsuper_training and hparams['lang_embed_init'] != 'none':
            target_id = self.lang2id[hparams['unsuper_ds']]
            init_mode = hparams['lang_embed_init']
            if init_mode == 'mean':
                def pick_source(weight):
                    # Average of the rich languages' embeddings.
                    rows = [weight[self.lang2id[lang]] for lang in self.rich_langs]
                    return torch.stack(rows, 0).mean(0)
            else:
                source_id = self.lang2id[init_mode]

                def pick_source(weight):
                    # Copy of one specific language's embedding.
                    return weight[source_id]

            t2s_weight = self.model_t2s.lang_embed.weight.data
            t2s_weight[target_id] = pick_source(t2s_weight)
            if self.s2t_config['use_lang_embed']:
                s2t_weight = self.model_s2t.lang_embed.weight.data
                s2t_weight[target_id] = pick_source(s2t_weight)

        # Print only the architectures that are actually trained.
        for skip_flag, module, tag in (('t2s_only', self.model_s2t, 's2t'),
                                       ('s2t_only', self.model_t2s, 't2s')):
            if not hparams[skip_flag]:
                print_arch(module, tag)

    def get_next_unpaired_sample(self, direction):
        """Return the next back-translation batch for ``direction``.

        When the stored iterator cannot deliver a batch, the dataset is rebuilt
        from every ``*.data`` file under ``<work_dir>/bt_data/<direction>_data``
        and the first batch of the new loader is returned.
        """
        bt_dir = f'{self.trainer.work_dir}/bt_data/{direction}_data'
        raw_data_path = f"{hparams['binary_data_dir']}/{hparams['unsuper_ds']}/train"
        batch_iter = self.unsup_train_dataloaders[direction]
        try:
            return next(batch_iter)
        except Exception:
            shards = []
            for data_path in sorted(glob.glob(f'{bt_dir}/*.data')):
                # The dataset takes the path without the ".data" suffix.
                shards.append(UnsuperTTSBidirBTDataset(data_path[:-5], raw_data_path, shuffle=True))
            ds = BaseConcatDataset(shards)
            token_budget = hparams[f'{direction}_max_tokens']
            loader = self.build_dataloader(ds, True, token_budget, self.max_sentences)
            batch_iter = iter(loader)
            self.unsup_train_dataloaders[direction] = batch_iter
            print(f"| Rebuild {direction} dataset. Number of items: {len(ds)}.")
            return next(batch_iter)

    def build_scheduler(self, optimizer):
        """Return warmup schedulers for the first two optimizers, or ``None``.

        Schedulers are only created when ``hparams['scheduler'] == 'warmup'``.
        """
        if hparams['scheduler'] != 'warmup':
            return None
        schedulers = []
        for optm_idx in (0, 1):  # s2t, then t2s
            schedulers.append(WarmupSchedule(optimizer[optm_idx], hparams['lr'], hparams['warmup_updates']))
        return schedulers

    def on_train_start(self):
        """Prepare runner bookkeeping and output directories for unsupervised training."""
        if not self.unsuper_training:
            return
        self.infer_runners = []
        self.n_infer_runners = hparams['n_infer_runners']
        self.exp_name = hparams['exp_name']
        self.s2t_data_dir = f'checkpoints/{self.exp_name}/bt_data/s2t_data'
        self.t2s_data_dir = f'checkpoints/{self.exp_name}/bt_data/t2s_data'
        for bt_dir in (self.s2t_data_dir, self.t2s_data_dir):
            os.makedirs(bt_dir, exist_ok=True)

    def start_infer_runners(self, direction, use_gpu=False):
        """Regenerate back-translation data for ``direction`` with Ray workers.

        The current model weights are saved first, the cached unpaired
        dataloaders are reset, and ``n_infer_runners`` remote ``inference_job``
        tasks are started. The call blocks until every task has finished.

        For ``direction == 't2s'`` the workers receive ``infer_s2t`` (they
        transcribe speech to create t2s training pairs); for ``'s2t'`` they
        receive ``infer_t2s``.

        :param direction: ``'s2t'`` or ``'t2s'``
        :param use_gpu: currently unused; it is not forwarded to the workers
        """
        ray_init()
        try:
            mem_GB_required = 4
            self.inference_job_remote = ray.remote(
                num_cpus=1, memory=mem_GB_required * 1024 * 1024 * 1024, resources={'worker_flags': 1})(inference_job)
            self.sync_model_states()
            # Drop cached iterators so the next request rebuilds them from new data.
            self.unsup_train_dataloaders = {'s2t': None, 't2s': None}
            for i in range(self.n_infer_runners):
                print(f"| Start IDLE runner: #{i}.")
                task_i = self.inference_job_remote.remote(
                    i, self.n_infer_runners, self.get_models, hparams,
                    hparams['s2t_model_config'], hparams['t2s_model_config'],
                    self.txt_dict_size, self.hb_dict_size,
                    self.s2t_data_dir, self.t2s_data_dir,
                    self.infer_s2t if direction == 't2s' else None,
                    self.infer_t2s if direction == 's2t' else None)
                self.infer_runners.append(task_i)
            # The context manager closes the progress bar even if waiting fails.
            with tqdm(total=self.n_infer_runners, desc='inferring on CPU') as pbar:
                while len(self.infer_runners) > 0:
                    ready_runners, self.infer_runners = ray.wait(self.infer_runners)
                    pbar.update(len(ready_runners))
        finally:
            # Always release the Ray session, also when scheduling or waiting raises.
            ray_shutdown()

    def sync_model_states(self):
        """Save CPU copies of both models' state dicts to ``<work_dir>/model_latest.pt``."""

        def cpu_state(module):
            return {key: tensor.cpu() for key, tensor in module.state_dict().items()}

        snapshot = {
            't2s_state_dict': cpu_state(self.model_t2s),
            's2t_state_dict': cpu_state(self.model_s2t),
        }
        ckpt_path = f'{self.trainer.work_dir}/model_latest.pt'
        print("| save model states")
        torch.save(snapshot, ckpt_path)

    def add_LSGAN_losses(self, p, target, ret, name='A'):
        if isinstance(p, torch.Tensor):
            ret[name] = F.mse_loss(p, p.new_ones(p.size()) * target)
        else:
            for i, p_i in enumerate(p):
                ret[f'{name}{i}'] = F.mse_loss(p_i, p_i.new_ones(p_i.size()) * target)

    def run_model(self, model, direction, sample=None, return_outputs=False):
        """Run one forward pass for ``direction`` and collect its losses.

        :param model: model called in the ``'s2t'`` direction; the ``'t2s'``
            direction always uses ``self.model_t2s``
        :param direction: ``'s2t'`` or ``'t2s'``
        :param sample: batch to use; when ``None`` (or replaced below) a batch
            is fetched from the task's own dataloaders
        :param return_outputs: also return the model output dict
        :return: ``(losses, ali)`` or ``(losses, ali, output)``. ``ali`` is a
            NumPy array of alignments, or ``None`` in the ``'t2s'`` direction
            when the batch already provides ``mel2ph``.
        """
        # Always False here, so the "no text" branches below never run.
        no_text = False
        if self.unsuper_training and self.training:
            use_sup_data_prob = hparams.get('use_sup_data_prob', 0)
            # Use back-translation data unless this direction is listed in
            # 'train_with_all_langs' and the random draw selects supervised data.
            if (direction not in hparams['train_with_all_langs']) or random.random() > use_sup_data_prob:
                sample = self.get_next_unpaired_sample(direction)
                # The chance of merging two batches grows with the global step.
                if random.random() < hparams.get('concat_prob', 0) * (self.global_step / 100000) and (
                        not hparams['use_spk_embed'] or direction == 's2t'
                ):
                    sample_2 = self.get_next_unpaired_sample(direction)
                    if sample_2['nsamples'] > 1 and sample['nsamples'] > 1:
                        # merge_samples may return None; the supervised fallback
                        # below then takes over.
                        sample = merge_samples([sample, sample_2])
        if sample is None:
            sample = self.get_next_train_sample(direction)
        sample = move_to_cuda(sample)
        if direction == 's2t':
            s2t_inp = sample['mels']
            output, gate_outputs, alignments = model(
                sample['lang_ids'], s2t_inp, sample['s_lengths'], sample['t_tokens'], sample['t_lengths'])
            # Targets are the tokens shifted by one; index 0 is ignored.
            xe = F.cross_entropy(output.transpose(1, 2), sample['t_tokens'][:, 1:], ignore_index=0)
            losses = {'xe': xe}
            ali = alignments.detach()
            _, mel2ph, _ = get_best_alignments(ali.clone(), sample['t_lengths'], sample['s_lengths'])
            # Kept on the task object for use outside this method.
            self.mel2ph = mel2ph.data
            self.t_tokens = sample['t_tokens'].data
            ali = ali.cpu().numpy()
            output = {'token_outs': output, 'mel2ph': mel2ph}
        else:
            losses = {}
            ali = None
            mel2ph = None
            if 'mel2ph' in sample:
                mel2ph = sample['mel2ph']
            if mel2ph is None:
                # No alignment in the batch: derive it from the s2t model
                # without tracking gradients.
                self.model_s2t.eval()
                with torch.inference_mode():
                    s2t_inp = sample['mels']
                    _, _, alignments = self.model_s2t(
                        sample['lang_ids'], s2t_inp, sample['s_lengths'], sample['t_tokens'], sample['t_lengths'])[:3]
                    ali = alignments.detach()
                    _, mel2ph, _ = get_best_alignments(ali.clone(), sample['t_lengths'], sample['s_lengths'])
                    ali = ali.detach().cpu().numpy()
                    mel2ph = mel2ph.data.to(sample['t_tokens'].device)
                if self.training:
                    self.model_s2t.train()
            # Shift indices by one and zero them on frames whose mel row is all zeros.
            mel2ph = (mel2ph + 1) * (sample['mels'].abs().sum(-1) > 0).long()
            spk_embed = sample.get('spk_embed')
            if no_text:
                sample['t_tokens'] = None
            output = self.model_t2s(sample['t_tokens'], sample['lang_ids'], infer=False,
                                    tgt_mels=sample['mels'], mel2ph=mel2ph, spk_embed=spk_embed,
                                    global_step=self.global_step)
            self.mel_out = None
            # 'kl_v' is the detached raw KL value; 'kl' is the clamped,
            # warmed-up and weighted term.
            losses['kl_v'] = output['kl'].detach()
            losses_kl = output['kl']
            losses_kl = torch.clamp(losses_kl, min=hparams['kl_min'])
            losses_kl = min(self.global_step / hparams['kl_start_steps'], 1) * losses_kl
            losses_kl = losses_kl * hparams['lambda_kl']
            losses['kl'] = losses_kl
            # The clipped target mels are also stored as self.mel_g, which the
            # discriminator step reads.
            self.mel_g = sample['mels'] = clip_seq_to_multiple(sample['mels'], hparams['frames_multiple'])
            self.add_mel_loss(output['mel_out'], sample['mels'], losses)
            if not no_text:
                self.add_dur_loss(output['dur'], output['mel2ph'], sample['t_tokens'], losses)
            # Stored for the adversarial losses in _training_step.
            self.mel_out = output['mel_out']
        if not return_outputs:
            return losses, ali
        else:
            return losses, ali, output

    def _training_step(self, _, batch_idx, optm_idx):
        """Run the training step for one optimizer.

        ``optm_idx`` follows the order returned by ``build_optimizer``:
        0 = s2t model, 1 = t2s model, 2 = mel discriminator.

        :param _: batch from the trainer; ignored, samples are fetched in ``run_model``
        :param batch_idx: unused here
        :param optm_idx: index of the optimizer being stepped
        :return: ``(total_loss, losses)``, or ``None`` when this optimizer is
            skipped in the current stage
        """
        losses = {}
        # Adversarial training starts after a step threshold and only with a
        # positive adversarial weight.
        disc_start = self.global_step >= hparams["disc_start_steps"] and hparams['lambda_mel_adv'] > 0
        cdsteps = hparams.get('change_dir_steps', 100)
        if hparams['run_both_dir']:
            skip_s2t = skip_t2s = False
            self.cur_stage = 'both'
        else:
            # Exactly one direction is trained per step. t2s is skipped when
            # 't2s_only' is off and one of these holds: 's2t_only' is set; the
            # current slot of the direction schedule is an s2t slot
            # (unsupervised mode); or t2s training has not started yet
            # (supervised mode).
            skip_t2s = not hparams['t2s_only'] and (
                    hparams['s2t_only'] or
                    self.unsuper_training and (self.global_step // cdsteps)
                    % hparams['changing_dir_period'] in hparams['s2t_slots'] or
                    not self.unsuper_training and self.global_step < hparams['t2s_start_steps']
            )
            skip_s2t = not skip_t2s
            self.cur_stage = 't2s' if skip_s2t else 's2t'
        # Regenerate back-translation data every `cdsteps` steps (once, on the
        # first optimizer), or on every call when 'rebuild_bt_data' is set.
        if self.unsuper_training and (self.global_step % cdsteps == 0 and optm_idx == 0 or hparams['rebuild_bt_data']):
            self.start_infer_runners('s2t' if skip_t2s else 't2s')
        if optm_idx == 0:  # s2t
            if skip_s2t:
                return
            losses_s2t, ali_s2t = self.run_model(self.model_s2t, 's2t')
            losses.update(losses_s2t)
        if optm_idx == 1:  # t2s
            if skip_t2s:
                return
            losses_t2s, ali_t2s = self.run_model(self.model_t2s, 't2s')
            losses.update(losses_t2s)
            if disc_start:
                # Generator side of the adversarial loss: predicted mels should
                # be scored as 1.
                mel_p = self.mel_out
                o_ = self.mel_disc(mel_p)
                p_, h_p_, start_frames = o_['y'], o_.get('h'), o_.get('start_frames')
                self.add_LSGAN_losses(p_, 1, losses, 'A')
                # Detach so the discriminator step does not backpropagate into t2s.
                self.mel_out = self.mel_out.detach()
        if optm_idx == 2:
            # Discriminator step: self.mel_g is scored against 1 ('R') and the
            # predicted mels against 0 ('F').
            if not disc_start or skip_t2s:
                return
            mel_p = self.mel_out
            B = mel_p.shape[0]
            o = self.mel_disc(torch.cat([self.mel_g, mel_p], 0))
            p = [y[:B] for y in o['y']]
            p_ = [y[B:] for y in o['y']]
            self.add_LSGAN_losses(p, 1, losses, 'R')
            self.add_LSGAN_losses(p_, 0, losses, 'F')
        total_losses = sum(losses.values())
        return total_losses, losses

    def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, losses=None):
        """Add the masked duration loss to ``losses['pdur']``.

        The mean squared error is computed between ``log(dur_pred + 1)`` and
        ``log(dur_gt + 1)``, averaged over the positions where ``txt_tokens``
        is non-zero and scaled by ``hparams['lambda_ph_dur']``. Because the
        logarithm is applied here, ``dur_pred`` is treated as a linear-scale
        duration.

        :param dur_pred: [B, T], float, predicted durations
        :param mel2ph: [B, T_mel], frame-to-token indices used to derive the
            ground-truth durations
        :param txt_tokens: [B, T]
        :param losses: dict that receives the ``'pdur'`` entry; a new dict is
            created when omitted
        :return: the ``losses`` dict
        """
        if losses is None:
            # The default used to crash on item assignment; create the dict instead.
            losses = {}
        _, T = txt_tokens.shape
        nonpadding = (txt_tokens != 0).float()
        dur_gt = mel2token_to_dur(mel2ph, T).float() * nonpadding
        pdur = F.mse_loss((dur_pred + 1).log(), (dur_gt + 1).log(), reduction='none')
        pdur = (pdur * nonpadding).sum() / nonpadding.sum()
        losses['pdur'] = pdur * hparams['lambda_ph_dur']
        return losses
