import torch

from modules.commons.conv import TextConvEncoder
from modules.voice_conversion.vc_modules import ConvStacks
from torch import nn
from modules.commons.layers import Embedding
from modules.commons.rel_transformer import RelTransformerEncoder
from modules.tts.commons.align_ops import clip_seq_to_multiple, expand_states
from modules.tts.portaspeech.fvae import FVAE
from modules.tts.portaspeech.portaspeech import PortaSpeech
import torch.nn.functional as F


class SpeechGenerator(PortaSpeech):
    """PortaSpeech subclass with its own content embedding and encoder.

    The content input is either discrete ids (looked up with ``Embedding``) or
    continuous feature frames (projected with ``nn.Linear``). The embedded
    content is projected to ``hidden_size`` and passed through either a
    ``RelTransformerEncoder`` or a ``TextConvEncoder`` before being handed to
    the inherited ``run_decoder``.
    """

    def __init__(self, txt_dict_size, hparams, use_dur=False, vae_dims=None, out_dims=None,
                 content_discrete=True, encoder_type='fft'):
        """Build the generator.

        Args:
            txt_dict_size: Number of embedding entries when ``content_discrete``
                is true; otherwise the input feature size of the linear
                content projection.
            hparams: Hyper-parameter mapping. Stored on the instance as
                ``self.hp`` and also forwarded to the parent constructor.
            use_dur: When false, ``dur_predictor`` and ``length_regulator``
                are deleted and the encoder output is used without
                duration-based expansion.
            vae_dims: When not ``None``, ``self.fvae`` is (re)built with
                ``c_in=vae_dims``.
            out_dims: Forwarded to the parent constructor and used as the
                FVAE ``c_out``.
            content_discrete: Selects ``Embedding`` (true) or ``nn.Linear``
                (false) for the content front-end.
            encoder_type: ``'fft'`` selects ``RelTransformerEncoder``; any
                other value selects ``TextConvEncoder``.
        """
        super().__init__(txt_dict_size, 0, hparams, out_dims)
        if vae_dims is not None:
            # Build an FVAE whose input channel count is `vae_dims`; it is
            # conditioned on features of size `self.hidden_size`.
            self.fvae = FVAE(
                c_in=vae_dims,
                c_out=out_dims,
                hidden_size=hparams['fvae_enc_dec_hidden'], c_latent=hparams['latent_size'],
                kernel_size=hparams['fvae_kernel_size'],
                enc_n_layers=hparams['fvae_enc_n_layers'],
                dec_n_layers=hparams['fvae_dec_n_layers'],
                c_cond=self.hidden_size,
                use_prior_flow=hparams['use_prior_flow'],
                flow_hidden=hparams['prior_flow_hidden'],
                flow_kernel_size=hparams['prior_flow_kernel_size'],
                flow_n_steps=hparams['prior_flow_n_blocks'],
                strides=hparams['fvae_strides'],
                encoder_type=hparams['fvae_encoder_type'],
                decoder_type=hparams['fvae_decoder_type'],
                pflow_nn_type=hparams['prior_flow_nn_type'],
            )
        self.hp = hp = hparams
        c_content_embed = hp['c_content_embed']
        self.content_discrete = content_discrete
        if content_discrete:
            # Index 0 is the padding id (see the `> 0` mask in run_text_encoder).
            self.emb = Embedding(txt_dict_size, c_content_embed, padding_idx=0)
        else:
            self.emb = nn.Linear(txt_dict_size, c_content_embed)
        self.embed_linear = nn.Linear(c_content_embed, hp['hidden_size'])
        if encoder_type == 'fft':
            self.encoder = RelTransformerEncoder(
                0, hp['hidden_size'], hp['hidden_size'],
                hp['ffn_hidden_size'], hp['num_heads'], hp['enc_layers'],
                hp['enc_ffn_kernel_size'], hp['dropout'], prenet=hp['enc_prenet'], pre_ln=hp['enc_pre_ln'])
        else:
            self.encoder = TextConvEncoder(0, hp['hidden_size'], hp['hidden_size'],
                                           hp['enc_dilations'], hp['enc_kernel_size'],
                                           layers_in_block=hp['layers_in_block'],
                                           norm_type=hp['enc_dec_norm'],
                                           post_net_kernel=hp.get('enc_post_net_kernel', 3))
        if self.hp['use_lang_embed']:
            self.lang_embed = Embedding(100, hp['hidden_size'])
        self.use_dur = use_dur
        if not use_dur:
            # Without duration modelling these sub-modules are never called;
            # dropping them removes their parameters from the model.
            del self.dur_predictor
            del self.length_regulator

    def forward(self, txt_tokens, lang_ids=None,
                spk_embed=None, spk_id=None, infer=False, tgt_mels=None, global_step=None,
                mel2ph=None, *args, **kwargs):
        """Encode the content (if any) and run the decoder.

        Args:
            txt_tokens: Content input, ``[B, T]`` ids in discrete mode or
                ``[B, T, C]`` features in continuous mode. May be ``None``, in
                which case the decoder is fed zeros shaped after ``mel2ph``.
            lang_ids: ``[B]`` language ids; required when
                ``hp['use_lang_embed']`` is set.
            spk_embed, spk_id: Forwarded to ``forward_style_embed``.
            infer: Forwarded to ``run_decoder``.
            tgt_mels: ``[B, T, n_mels]`` target frames. Forwarded to
                ``run_decoder`` and, when ``txt_tokens`` is ``None``, also
                used to derive the non-padding mask (so it is required then).
            global_step: Forwarded to ``run_decoder``.
            mel2ph: ``[B, T]`` alignment; forwarded to ``forward_dur`` when
                durations are used, and required when ``txt_tokens`` is ``None``.

        Returns:
            dict with ``'mel_out'`` and ``'mel_out_fvae'`` (the same object).
            When ``txt_tokens`` is given it also holds ``'nonpadding'`` and
            ``'decoder_inp'``, plus anything ``run_text_encoder`` /
            ``run_decoder`` write into it.
        """
        ret = {}
        style_embed = self.forward_style_embed(spk_embed, spk_id)
        # [B, 1, H] so it broadcasts over time; plain 0 disables the term.
        lang_embed = self.lang_embed(lang_ids)[:, None, :] if self.hp['use_lang_embed'] else 0
        if txt_tokens is not None:
            x, tgt_nonpadding = self.run_text_encoder(txt_tokens, style_embed, lang_embed, mel2ph, ret)
            x = x * tgt_nonpadding
            ret['nonpadding'] = tgt_nonpadding
            ret['decoder_inp'] = x
        else:
            # No content: feed an all-zero condition. Note that this branch
            # does not populate ret['nonpadding'] / ret['decoder_inp'].
            B, T = mel2ph.shape
            # Allocate directly on the target device instead of CPU + copy.
            x = torch.zeros([B, T, self.hp['hidden_size']], device=mel2ph.device)
            # Frames whose mel bins are all zero are treated as padding.
            tgt_nonpadding = tgt_mels.abs().sum(-1) > 0
            tgt_nonpadding = tgt_nonpadding.float()[:, :, None]
        ret['mel_out_fvae'] = ret['mel_out'] = self.run_decoder(x, tgt_nonpadding, ret, infer, tgt_mels, global_step)
        return ret

    def run_text_encoder(self, txt_tokens, style_embed, lang_embed, mel2ph, ret):
        """Embed and encode the content, returning features and their mask.

        With ``use_dur`` the style/language embeddings are added before the
        encoder and the encoder output is expanded with ``mel2ph`` (as returned
        by ``forward_dur``, which is also stored in ``ret['mel2ph']``).
        Otherwise they are added after the encoder. In both cases the
        sequence is passed through ``clip_seq_to_multiple`` with
        ``hp['frames_multiple']``.

        Returns:
            ``(x, nonpadding)`` where ``nonpadding`` is a float mask with a
            trailing singleton dimension.
        """
        if self.content_discrete:
            # id 0 is padding
            src_nonpadding = (txt_tokens > 0).float()[:, :, None]
        else:
            # all-zero feature frames are padding
            src_nonpadding = (txt_tokens.abs().sum(-1) > 0).float()[:, :, None]
        ph_encoder_in = self.emb(txt_tokens)
        ph_encoder_in = self.embed_linear(ph_encoder_in)
        frames_multiple = self.hp['frames_multiple']
        if self.use_dur:
            ph_encoder_in = (ph_encoder_in + style_embed + lang_embed) * src_nonpadding
            x = self.encoder(ph_encoder_in)
            mel2ph = self.forward_dur(x, mel2ph, ret)
            ret['mel2ph'] = mel2ph
            x = expand_states(x, mel2ph)
            nonpadding = expand_states(src_nonpadding, mel2ph)
            x = clip_seq_to_multiple(x, frames_multiple)
            nonpadding = clip_seq_to_multiple(nonpadding, frames_multiple)
        else:
            # Clip before encoding so the encoder already sees the final length.
            ph_encoder_in = clip_seq_to_multiple(ph_encoder_in, frames_multiple)
            src_nonpadding = clip_seq_to_multiple(src_nonpadding, frames_multiple)
            x = (self.encoder(ph_encoder_in) + style_embed + lang_embed) * src_nonpadding
            nonpadding = src_nonpadding
        return x, nonpadding


class SpeechGeneratorVC(SpeechGenerator):
    """``SpeechGenerator`` variant that can add energy and pitch embeddings.

    On top of the parent's content encoding, the decoder input optionally
    receives an energy embedding and an encoded pitch embedding, and can be
    resampled along the time axis by a fixed ``rescale`` factor.
    """

    def __init__(self, txt_dict_size, hparams, use_dur=False, vae_dims=None, out_dims=None,
                 content_discrete=True, encoder_type='fft', rescale=1.0):
        """Build the generator.

        Args:
            txt_dict_size, hparams, use_dur, vae_dims, out_dims,
            content_discrete, encoder_type: Forwarded unchanged to
                ``SpeechGenerator.__init__``.
            rescale: Time-axis scale factor applied to the decoder input in
                ``forward``; ``1.0`` disables rescaling.
        """
        super().__init__(txt_dict_size, hparams, use_dur, vae_dims, out_dims,
                         content_discrete, encoder_type)
        self.rescale = rescale
        # energy embedding (256 bins, index 0 is the padding index)
        if hparams['use_energy']:
            self.energy_embed = Embedding(256, self.hidden_size, 0)
        # pitch embedding (300 entries, index 0 is the padding index) + conv encoder
        if hparams['use_pitch']:
            self.pitch_embed = Embedding(300, self.hidden_size, 0)
            self.pitch_encoder = ConvStacks(
                idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=3)

    def forward(self, txt_tokens, lang_ids=None,
                spk_embed=None, spk_id=None, pitch=None, energy=None,
                infer=False, tgt_mels=None, global_step=None, mel2ph=None, *args, **kwargs):
        """Encode content, add pitch/energy, optionally rescale, and decode.

        Args:
            txt_tokens: Content input, ``[B, T]`` ids in discrete mode or
                ``[B, T, C]`` features in continuous mode. May be ``None``, in
                which case the content term is all zeros shaped after ``mel2ph``.
            lang_ids: ``[B]`` language ids; required when
                ``hp['use_lang_embed']`` is set.
            spk_embed, spk_id: Forwarded to ``forward_style_embed``.
            pitch: Index tensor fed to ``pitch_embed``; required when
                ``hp['use_pitch']`` is set. Presumably quantized pitch ids in
                ``[0, 300)`` - TODO confirm against the caller.
            energy: Energy values; required when ``hp['use_energy']`` is set.
                Quantized here to integer bins (see below).
            infer: Forwarded to ``run_decoder``.
            tgt_mels: ``[B, T, n_mels]`` target frames. Forwarded to
                ``run_decoder``; required when ``txt_tokens`` is ``None``
                because the non-padding mask is derived from it.
            global_step: Forwarded to ``run_decoder``.
            mel2ph: ``[B, T]`` alignment; required when ``txt_tokens`` is ``None``.

        Returns:
            dict with ``'decoder_inp'`` and ``'mel_out'``; ``'nonpadding'`` is
            present only when ``txt_tokens`` is given.
        """
        ret = {}
        # Use `self.hp` throughout: it is the hparams attribute this class
        # hierarchy assigns itself, and the rest of the file reads from it.
        hp = self.hp
        style_embed = self.forward_style_embed(spk_embed, spk_id)
        # [B, 1, H] so it broadcasts over time; plain 0 disables the term.
        lang_embed = self.lang_embed(lang_ids)[:, None, :] if hp['use_lang_embed'] else 0
        if txt_tokens is not None:
            x, tgt_nonpadding = self.run_text_encoder(txt_tokens, style_embed, lang_embed, mel2ph, ret)
            x = x * tgt_nonpadding
            # NOTE(review): this mask is stored before the rescaling/clipping
            # below, so its length may differ from ret['decoder_inp'] -
            # confirm that downstream consumers expect that.
            ret['nonpadding'] = tgt_nonpadding
        else:
            B, T = mel2ph.shape
            # Allocate directly on the target device instead of CPU + copy.
            x = torch.zeros([B, T, hp['hidden_size']], device=mel2ph.device)
            # Frames whose mel bins are all zero are treated as padding.
            tgt_nonpadding = tgt_mels.abs().sum(-1) > 0
            tgt_nonpadding = tgt_nonpadding.float()[:, :, None]
        if hp['use_energy']:
            # Quantize: floor(energy * 64), capped at the last of the 256 bins.
            energy = torch.clamp(energy * 256 // 4, max=255).long()
            h_energy = self.energy_embed(energy)
            x = x + h_energy
        if hp['use_pitch']:
            h_pitch = self.pitch_encoder(self.pitch_embed(pitch))
            x = x + h_pitch
        # Re-mask: the added embeddings are non-zero on padded frames too.
        x = x * tgt_nonpadding
        if self.rescale != 1.0:
            # interpolate works on [B, C, T]; features are resampled linearly,
            # the mask with nearest-neighbour.
            x = F.interpolate(x.transpose(1, 2), scale_factor=self.rescale, mode='linear').transpose(1, 2)
            tgt_nonpadding = F.interpolate(
                tgt_nonpadding.transpose(1, 2), scale_factor=self.rescale, mode='nearest').transpose(1, 2)
            if tgt_mels is not None:
                # Only truncates: targets longer than the rescaled input are cut.
                tgt_mels = tgt_mels[:, :x.shape[1]]
        frames_multiple = hp['frames_multiple']
        x = clip_seq_to_multiple(x, frames_multiple)
        tgt_nonpadding = clip_seq_to_multiple(tgt_nonpadding, frames_multiple)
        ret['decoder_inp'] = x
        ret['mel_out'] = self.run_decoder(x, tgt_nonpadding, ret, infer, tgt_mels, global_step)
        return ret

    def run_text_encoder(self, txt_tokens, style_embed, lang_embed, mel2ph, ret):
        """Embed and encode the content, returning features and their mask.

        Same flow as the parent implementation except that no
        ``clip_seq_to_multiple`` is applied here; ``forward`` clips once,
        after the optional rescaling.

        Returns:
            ``(x, nonpadding)`` where ``nonpadding`` is a float mask with a
            trailing singleton dimension.
        """
        if self.content_discrete:
            # id 0 is padding
            src_nonpadding = (txt_tokens > 0).float()[:, :, None]
        else:
            # all-zero feature frames are padding
            src_nonpadding = (txt_tokens.abs().sum(-1) > 0).float()[:, :, None]
        ph_encoder_in = self.emb(txt_tokens)
        ph_encoder_in = self.embed_linear(ph_encoder_in)
        if self.use_dur:
            ph_encoder_in = (ph_encoder_in + style_embed + lang_embed) * src_nonpadding
            x = self.encoder(ph_encoder_in)
            mel2ph = self.forward_dur(x, mel2ph, ret)
            ret['mel2ph'] = mel2ph
            x = expand_states(x, mel2ph)
            nonpadding = expand_states(src_nonpadding, mel2ph)
        else:
            x = (self.encoder(ph_encoder_in) + style_embed + lang_embed) * src_nonpadding
            nonpadding = src_nonpadding
        return x, nonpadding
