from modules.commons.common_layers import *
from modules.commons.common_layers import Embedding
from modules.fastspeech.conformer.conformer import ConformerDecoder, ConformerEncoder
from modules.fastspeech.fast_tacotron import TacotronEncoder, DecoderRNN, Tacotron2Encoder
from modules.fastspeech.speedy_speech.speedy_speech import ConvBlocks
from modules.fastspeech.tts_modules import FastspeechDecoder, PitchPredictor, \
    EnergyPredictor,  FastspeechEncoder, EmbeddingPlusConv

from modules.fastspeech.wavenet_decoder import WN
from modules.fastspeech.videotts.video_encoder import EncoderLip2Wav, Conv3dTransformerEncoder, ResNetTransformerEncoder
from modules.fastspeech.videotts.attention import GravesAttention, AdditiveAttention, MultiHeadedAttention, LSA
from utils.tts_utils import make_non_pad_mask, make_pad_mask, get_diagonal_mask
from utils.hparams import hparams
from utils.text_process import ctc_symbols
from utils.pitch_utils import f0_to_coarse, denorm_f0
from modules.fastspeech.videotts.resnet50_ft_dag import resnet50_ft_dag, Img2SpkNet
from modules.fastspeech.videotts.mutual_information import MIEsitmator
from utils.monotonic_align_utils import maximum_path



# Text-encoder builders, selected by hparams['encoder_type'].
# Each builder is invoked as builder(hp, embed_tokens, d): the hparams dict, the
# token embedding module and the dictionary object.
def _build_fft_encoder(hp, embed_tokens, d):
    return FastspeechEncoder(
        embed_tokens,
        hp['hidden_size'],
        hp['enc_layers'],
        hp['enc_ffn_kernel_size'],
        num_heads=hp['num_heads'],
    )


def _build_tacotron_encoder(hp, embed_tokens, d):
    return TacotronEncoder(
        hp['hidden_size'],
        len(d),
        hp['hidden_size'],
        K=hp['encoder_K'],
        num_highways=4,
        dropout=hp['dropout'],
    )


def _build_tacotron2_encoder(hp, embed_tokens, d):
    return Tacotron2Encoder(len(d), hp['hidden_size'])


def _build_conformer_encoder(hp, embed_tokens, d):
    return ConformerEncoder(embed_tokens, len(d))


FS_ENCODERS = {
    'fft': _build_fft_encoder,
    'tacotron': _build_tacotron_encoder,
    'tacotron2': _build_tacotron2_encoder,
    'conformer': _build_conformer_encoder,
}

# Decoder builders, selected by hparams['decoder_type']; each takes the hparams dict.
FS_DECODERS = dict(
    fft=lambda hp: FastspeechDecoder(
        hp['hidden_size'],
        hp['dec_layers'],
        hp['dec_ffn_kernel_size'],
        hp['num_heads'],
    ),
    rnn=lambda hp: DecoderRNN(
        hp['hidden_size'], hp['decoder_rnn_dim'], hp['dropout']
    ),
    conv=lambda hp: ConvBlocks(
        hp['hidden_size'], hp['hidden_size'], hp['dec_dilations']
    ),
    wn=lambda hp: WN(
        hp['hidden_size'], hp['hidden_size'], kernel_size=3, n_layers=hp['dec_layers']
    ),
    conformer=lambda hp: ConformerDecoder(hp['hidden_size']),
)

# Text/query attention builders, selected by hparams['text_video_att_type'].
# Each module is constructed from the kwargs dict stored under hparams['<type>_att_params'].
TEXT_VIDEO_ATT = dict(
    gmm=lambda hp: GravesAttention(**hp["gmm_att_params"]),
    add=lambda hp: AdditiveAttention(**hp["add_att_params"]),
    scaled_dot=lambda hp: MultiHeadedAttention(**hp["scaled_dot_att_params"]),
    torch_dot=lambda hp: MultiheadAttention(**hp["torch_dot_att_params"]),
    lsa=lambda hp: LSA(**hp["lsa_att_params"]),
)

# Video encoder builders, selected by hparams['video_encoder_type'].
# Each module is constructed from the kwargs dict stored under hparams['<type>_params'].
VIDEO_ENCODERS = dict(
    conv3d_rnn=lambda hp: EncoderLip2Wav(**hp["conv3d_rnn_params"]),
    conv3d_fft=lambda hp: Conv3dTransformerEncoder(**hp["conv3d_fft_params"]),
    conv2d_fft=lambda hp: ResNetTransformerEncoder(**hp["conv2d_fft_params"]),
)


class VideoTts(nn.Module):
    """Text-to-speech acoustic model driven by a per-frame query sequence.

    Encoded text is attended by a query sequence. By default the query is encoded
    video; when ``hparams['use_pitch_as_query']`` is set, it is an embedded pitch
    sequence instead. The attended text context is repeated along time. It is then
    optionally enriched with speaker, face-image, pitch and energy embeddings, and
    decoded to ``out_dims`` channels.

    Every architectural choice is read from the global ``hparams`` dict.
    """

    def __init__(self, dictionary=100, out_dims=None):
        """Build the sub-modules selected by ``hparams``.

        Args:
            dictionary: vocabulary size for the token embedding; also handed to the
                encoder factory. NOTE(review): the 'tacotron', 'tacotron2' and
                'conformer' encoder factories call ``len()`` on it, which the int
                default does not support.
            out_dims: output channels of ``mel_out``; ``hparams['audio_num_mel_bins']``
                when None.
        """
        super().__init__()
        self.dictionary = dictionary
        self.padding_idx = 0
        self.enc_layers = hparams['enc_layers']
        self.dec_layers = hparams['dec_layers']
        self.hidden_size = hparams['hidden_size']
        self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
        self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary)
        if not hparams['use_pitch_as_query']:
            # A video encoder is only needed when video (not pitch) provides the query.
            self.video_encoder = VIDEO_ENCODERS[hparams['video_encoder_type']](hparams)
        self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
        self.text_video_att = TEXT_VIDEO_ATT[hparams["text_video_att_type"]](hparams)
        self.out_dims = out_dims
        if out_dims is None:
            self.out_dims = hparams['audio_num_mel_bins']
        self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)

        if hparams['use_img_spk_embed']:
            # Face-recognition backbone plus a projection to a hidden_size speaker embedding.
            self.vgg_encoder = resnet50_ft_dag(weights_path=hparams['face_recognizer_weight_path'])
            self.vgg_embed_proj = Img2SpkNet(hparams['hidden_size'], drop_rate=0.2)

        if hparams['use_mutual_information']:
            vocab_size = len(ctc_symbols)
            decoder_dim = hparams['hidden_size']
            self.mi = MIEsitmator(vocab_size, decoder_dim, decoder_dim, dropout=0.5)

        if hparams['use_spk_id']:
            self.spk_embed_proj = Embedding(hparams['num_spk'], self.hidden_size)
        elif hparams['use_spk_embed']:
            # Expects 256-dim external speaker vectors (the input size of this Linear).
            self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
        predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
        if hparams['use_pitch_as_query']:
            if hparams['pitch_embed_type'] == 'embed_conv':
                self.pitch_embed = EmbeddingPlusConv(300, self.hidden_size, use_pos_embed=hparams['use_pitch_pos_embed'])
            elif hparams['pitch_embed_type'] == 'embed_selfattn':
                token_embed = Embedding(300, self.hidden_size, self.padding_idx)
                self.pitch_embed = FastspeechEncoder(token_embed, num_layers=2,
                                                     use_pos_embed=hparams['use_pitch_pos_embed'])
        if hparams['use_pitch_embed']:
            # NOTE(review): if use_pitch_as_query is also enabled, this overwrites the
            # query pitch_embed built just above.
            self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
            self.pitch_predictor = PitchPredictor(
                self.hidden_size,
                n_chans=predictor_hidden,
                n_layers=hparams['predictor_layers'],
                dropout_rate=hparams['predictor_dropout'],
                odim=2 if hparams['pitch_type'] == 'frame' else 1,
                padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
        if hparams['use_energy_embed']:
            self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
            self.energy_predictor = EnergyPredictor(
                self.hidden_size,
                n_chans=predictor_hidden,
                n_layers=hparams['predictor_layers'],
                dropout_rate=hparams['predictor_dropout'], odim=1,
                padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
        if hparams['use_align_branch']:
            self.align_branch = PitchPredictor(hparams['hidden_size'], n_layers=2, n_chans=384,
                                               odim=hparams['audio_num_mel_bins'])

    def build_embedding(self, dictionary_len, embed_dim):
        """Return a token Embedding of size ``dictionary_len`` x ``embed_dim`` with padding index 0."""
        num_embeddings = dictionary_len
        emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
        return emb

    def forward(self, txt_tokens, video, txt_lens, vid_lens, pitch=None, pitch_lens=None, mel2ph=None, spk_embed=None,
                f0=None, uv=None, energy=None, skip_decoder=False, spk_img=None, infer=False, **kwargs):
        """Encode text, attend it with the query, add variance embeddings and decode.

        Args:
            txt_tokens: text token ids; id 0 is treated as padding.
            video: input for ``video_encoder``; required unless ``use_pitch_as_query``.
            txt_lens: valid text lengths (attention key lengths).
            vid_lens: valid video lengths (query lengths when video is the query).
            pitch: input to ``pitch_embed`` when ``use_pitch_as_query`` (presumably
                coarse pitch ids below 300 — confirm against the data pipeline).
            pitch_lens: query lengths when ``use_pitch_as_query``.
            mel2ph: unused here.
            spk_embed: speaker id or speaker vector, projected by ``spk_embed_proj``
                when ``use_spk_id`` / ``use_spk_embed``.
            f0, uv: pitch targets for ``add_pitch``; predictions are used when None.
            energy: energy target for ``add_energy``; the prediction is used when None.
            skip_decoder: return before running the decoder.
            spk_img: face images shaped (B, H, W, 3), used when ``use_img_spk_embed``.
            infer, **kwargs: forwarded to ``run_decoder``.

        Returns:
            dict with 'diag_loss', 'entropy_loss', 'attn_map', 'diagonal_mask' and
            'decoder_inp'. Depending on hparams it may also hold 'video_hidden',
            'pitch_pred', 'f0_denorm', 'energy_pred' and 'align_mel_out'. It holds
            'mel_out' unless ``skip_decoder``.
        """
        ret = {}
        encoder_out = self.encoder(txt_tokens)  # [B, T, C]
        src_nonpadding = (txt_tokens > 0).float()[:, :, None]

        # Project the speaker id/vector to a (B, 1, H) embedding that broadcasts over
        # time; the scalar 0 disables it.
        if hparams['use_spk_embed'] or hparams['use_spk_id']:
            spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
        else:
            spk_embed = 0

        if hparams['use_img_spk_embed']:
            assert spk_img is not None
            # input image RGB, shape (B, H, W, 3)
            spk_img = spk_img.permute(0, 3, 1, 2)  # (B, H, W, 3) -> (B, 3, H, W)
            if hparams['fixed_face_rec']:
                # Frozen face recognizer: eval mode and no gradients.
                self.vgg_encoder.eval()
                with torch.no_grad():
                    vgg_face_embed = self.vgg_encoder(spk_img).detach()
            else:
                # Pretrained face recognizer is fine-tuned with the rest of the model.
                vgg_face_embed = self.vgg_encoder(spk_img)
            img_spk_embed = self.vgg_embed_proj(vgg_face_embed)[:, None, :]  # (B, 1, H)
        else:
            img_spk_embed = 0
        # Text-side attention keys/values: encoder output plus speaker embedding, zeroed at padding.
        dur_inp = (encoder_out + spk_embed) * src_nonpadding

        if not hparams['use_pitch_as_query']:
            assert video is not None
            query = self.video_encoder(video, vid_lens)  # [B, T_vid, H_vid], H_vid == hparams['hidden_size']
            assert query.shape[-1] == hparams['hidden_size']
            query_lens = vid_lens
            repeat_num = hparams["vid_mel_repeat_num"]
            ret['video_hidden'] = query
        else:
            query = self.pitch_embed(pitch)  # [B, T, H]
            query_lens = pitch_lens
            repeat_num = 1

        decoder_inp, video_text_att, tgt_nonpadding, diagonal_loss, diagonal_mask, entropy_loss = \
            self.get_video_text_att(dur_inp, query, txt_lens, query_lens, repeat_num)

        ret['diag_loss'] = diagonal_loss
        ret['entropy_loss'] = entropy_loss
        ret['attn_map'] = video_text_att
        ret['diagonal_mask'] = diagonal_mask
        decoder_inp_origin = decoder_inp  # (B, T_dec, H)

        # Pitch/energy predictors see the attended context plus speaker embeddings.
        pitch_inp = (decoder_inp_origin + spk_embed + img_spk_embed) * tgt_nonpadding
        if hparams['use_pitch_embed']:
            decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, tgt_nonpadding.squeeze(-1), ret)
        if hparams['use_energy_embed']:
            decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)

        ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed + img_spk_embed) * tgt_nonpadding

        if hparams['use_align_branch']:
            ret['align_mel_out'] = self.align_branch(decoder_inp)  # (B, T_mel, H)

        if skip_decoder:
            return ret
        ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
        return ret

    def get_video_text_att(self, text_hidden, video_hidden, ilens, olens, vid_mel_repeat_num):
        """Attend from each query frame over the text and expand the result in time.

        Args:
            text_hidden: text keys/values, (B, T_txt, H).
            video_hidden: query sequence, (B, T_vid, H).
            ilens: valid text lengths.
            olens: valid query lengths.
            vid_mel_repeat_num: number of output frames produced per query frame.

        Returns:
            A tuple of six items:
            - repeated context, (B, T_vid * vid_mel_repeat_num, H);
            - attention scores, (B, T_vid, T_txt);
            - repeated target non-padding mask, (B, T_vid * vid_mel_repeat_num, 1);
            - diagonal loss;
            - diagonal mask;
            - entropy loss (None unless ``use_entropy_loss``).

        Raises:
            ValueError: if ``hparams['text_video_att_type']`` has no branch here.
        """
        att_type = hparams['text_video_att_type']
        T_vid = video_hidden.size(1)
        pad_mask = make_pad_mask(ilens).to(text_hidden.device).unsqueeze(1)  # (B, 1, T_txt)
        tgt_padding_mask = make_pad_mask(olens).to(text_hidden.device)
        if att_type in ['gmm', 'lsa']:
            # These attentions are called once per query frame, with the frame index t
            # and the text lengths.
            attn_scores = []
            for t in range(T_vid):
                scores = self.text_video_att(text_hidden, video_hidden[:, t, :], t, ilens)
                attn_scores.append(scores)  # (B x 1 x T_txt)
            attn_scores = torch.cat(attn_scores, 1)  # (B, T_vid, T_txt)
            context = attn_scores @ text_hidden  # (B, T_vid, T_txt) @ (B, T_txt, H_txt) = (B, T_vid, H_txt)
        elif att_type == 'add':
            attn_scores = self.text_video_att(text_hidden, video_hidden, mask=pad_mask)
            context = attn_scores @ text_hidden
        elif att_type == 'scaled_dot':
            context = self.text_video_att(video_hidden, text_hidden, text_hidden, mask=pad_mask)
            # Cached weights are (batch, head, time1, time2); keep only the first head.
            attn_scores = self.text_video_att.attn[:, 0]
        elif att_type == 'torch_dot':
            # Time-first call convention: inputs are passed as (T, B, H) and the returned
            # context is (T_vid, B, H).
            # NOTE(review): torch.nn.MultiheadAttention has no `need_head_weights`
            # argument, so this `MultiheadAttention` presumably comes from the project's
            # `common_layers` star import. Confirm that its second return value is a
            # (B, T_vid, T_txt) tensor, as get_diagonal_loss below assumes.
            context, attn_scores = self.text_video_att(
                query=video_hidden.transpose(0, 1),
                key=text_hidden.transpose(0, 1),
                value=text_hidden.transpose(0, 1),
                key_padding_mask=pad_mask.squeeze(1),
                need_head_weights=True,
            )
            context = context.transpose(0, 1)  # (T, B, H) -> (B, T, H)
        else:
            # Explicit error instead of `assert False`, which `python -O` strips.
            raise ValueError(f"Unsupported text_video_att_type: {att_type!r}")
        # Per-item query/text length ratio, used to place the diagonal band.
        # as_tensor avoids new_tensor's copy-construct warning when lengths are tensors.
        attn_ks = (torch.as_tensor(olens, dtype=text_hidden.dtype, device=text_hidden.device)
                   / torch.as_tensor(ilens, dtype=text_hidden.dtype, device=text_hidden.device))
        diagonal_loss, diagonal_mask, entropy_loss = self.get_diagonal_loss(
            attn_scores, attn_ks, pad_mask.squeeze(1), tgt_padding_mask)
        vid_nonpadding = make_non_pad_mask(olens).float().to(text_hidden.device)[:, :, None]  # (B, T_vid, 1)
        if hparams['infer'] and hparams['use_mas']:
            # At inference, optionally rebuild the context from the path returned by
            # maximum_path over the log attention (Monotonic Alignment Search); batch size 1 only.
            print("| Use Monotonic Alignment Search!")
            assert attn_scores.shape[0] == 1
            mas_path = maximum_path(attn_scores.log().transpose(1, 2),
                                    mask=torch.ones_like(attn_scores).transpose(1, 2)).transpose(1, 2)
            # Stashed in the global hparams dict; presumably read elsewhere — verify.
            hparams['mas_path'] = mas_path
            context = mas_path @ text_hidden  # (B, T_vid, T_txt) @ (B, T_txt, H_txt) = (B, T_vid, H_txt)

        context = context * vid_nonpadding
        # Optionally add the query itself (with dropout) as a residual.
        context = (context + F.dropout(video_hidden, p=0.5, training=self.training)) \
            if (hparams['add_pitch'] or hparams['add_query']) else context
        # Repeat each query frame vid_mel_repeat_num times along time.
        repeated_context = torch.repeat_interleave(context, vid_mel_repeat_num, dim=1)
        repeated_tgt_nonpadding = torch.repeat_interleave(vid_nonpadding, vid_mel_repeat_num, dim=1)
        return repeated_context, attn_scores, repeated_tgt_nonpadding, diagonal_loss, diagonal_mask, entropy_loss

    def get_diagonal_loss(self, attn, attn_ks, src_padding_mask, tgt_padding_mask):
        """Penalize attention mass that falls outside a diagonal band.

        Args:
            attn: attention weights, (B, T_vid, T_txt).
            attn_ks: per-item query/text length ratio, passed to ``get_diagonal_mask``.
            src_padding_mask: (B, T_txt), nonzero at padded text positions, or None.
            tgt_padding_mask: (B, T_vid), nonzero at padded query positions, or None.

        Returns:
            diagonal_loss: -mean(focus rate), or -log(mean(focus rate)) when
                ``use_log_diag_loss``.
            diagonal_mask: the band mask from ``get_diagonal_mask``.
            entropy_loss: total attention entropy divided by the number of valid
                query frames when ``use_entropy_loss``; otherwise None.
        """
        diagonal_mask = get_diagonal_mask(attn, attn_ks, band_width=hparams['diagonal_band_width'])
        if src_padding_mask is not None:
            attn = attn * (1 - src_padding_mask.float())[:, None, :]
        if tgt_padding_mask is not None:
            attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
        diagonal_attn = attn * diagonal_mask
        # Fraction of the (unmasked) attention mass inside the band, per batch item.
        diagonal_focus_rate = diagonal_attn.sum(-1).sum(-1) / attn.sum(-1).sum(-1)
        diagonal_loss = -diagonal_focus_rate.mean().log() if hparams['use_log_diag_loss'] else -diagonal_focus_rate.mean()
        entropy_loss = None
        if hparams['use_entropy_loss']:
            if tgt_padding_mask is not None:
                num_valid_frames = (~tgt_padding_mask.bool()).sum()
            else:
                # No padding mask: every query frame counts.
                num_valid_frames = attn.shape[0] * attn.shape[1]
            entropy_loss = - (attn * (attn + 1e-8).log()).sum(-1).sum() / num_valid_frames
            assert not entropy_loss.isnan(), entropy_loss
        return diagonal_loss, diagonal_mask, entropy_loss

    def add_energy(self, decoder_inp, energy, ret):
        """Predict energy and return the embedding of the target (or predicted) energy.

        ``hparams['predictor_grad']`` scales the gradient flowing from the predictor
        back into ``decoder_inp``; the forward value is unchanged. The prediction is
        stored in ``ret['energy_pred']``.
        """
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
        if energy is None:
            energy = energy_pred
        # Quantize to ids for the 256-entry embedding; only the upper bound is clamped.
        energy = torch.clamp(energy * 256 // 4, max=255).long()
        energy_embed = self.energy_embed(energy)
        return energy_embed

    def add_pitch(self, decoder_inp, f0, uv, tgt_nonpadding, ret, encoder_out=None):
        """Predict pitch and return the embedding of the coarse target (or predicted) f0.

        Args:
            decoder_inp: predictor input; its gradient is scaled by ``predictor_grad``.
            f0: f0 in the normalized space consumed by ``denorm_f0``, or None to use
                ``pitch_pred[:, :, 0]``.
            uv: unvoiced flags, or None to derive them from ``pitch_pred[:, :, 1] > 0``
                when ``use_uv``.
            tgt_nonpadding: (B, T) mask that is 0 at padded frames.
            ret: output dict; receives 'pitch_pred' and 'f0_denorm'.
            encoder_out: unused.
        """
        decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
        pitch_padding = tgt_nonpadding.eq(0)
        ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
        if f0 is None:
            f0 = pitch_pred[:, :, 0]
        if hparams['use_uv'] and uv is None:
            uv = pitch_pred[:, :, 1] > 0
        ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
        # NOTE(review): this in-place write zeroes padded frames of the caller's f0, or of
        # ret['pitch_pred'] through a view when f0 was predicted. f0 is not read again in
        # this method, so only callers that reuse that tensor observe it.
        f0[pitch_padding] = 0
        pitch = f0_to_coarse(f0_denorm)  # start from 0
        pitch_embed = self.pitch_embed(pitch)
        return pitch_embed

    def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
        """Decode ``decoder_inp``, project to ``out_dims`` channels and zero padded frames.

        ``ret``, ``infer`` and ``kwargs`` are accepted for interface compatibility and unused here.
        """
        x = decoder_inp  # [B, T, H]
        x = self.decoder(x)  # [B, T, hidden_size]
        x = self.mel_out(x)  # [B, T, out_dims]
        return x * tgt_nonpadding

    def out2mel(self, out):
        """Identity: the model output is already the mel representation."""
        return out
