# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from torch import nn

from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.modules import (
    LayerNorm,
    PositionalEmbedding,
    FairseqDropout,
    MultiheadAttention,
)
from fairseq import utils
from fairseq.data.data_utils import lengths_to_padding_mask
from fairseq.models.text_to_speech.tacotron2 import Postnet


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def model_init(m):
    """Re-initialise ``nn.Conv1d`` weights in place; other modules are left untouched.

    Intended for use with ``nn.Module.apply``. Conv1d weights get Xavier-uniform
    initialisation scaled by the ReLU gain; biases are not modified.
    """
    if not isinstance(m, nn.Conv1d):
        return
    relu_gain = torch.nn.init.calculate_gain("relu")
    nn.init.xavier_uniform_(m.weight, relu_gain)


def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """Build an ``nn.Embedding`` whose weights are drawn from N(0, embedding_dim ** -0.5).

    Note: the normal re-initialisation covers every row, including the
    ``padding_idx`` row, so that row is not left at zero.
    """
    table = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    scale = embedding_dim ** -0.5
    nn.init.normal_(table.weight, mean=0, std=scale)
    return table


class PositionwiseFeedForward(nn.Module):
    """Convolutional feed-forward sub-block with a residual connection and post-LayerNorm.

    Two Conv1d layers (in_dim -> hidden_dim -> in_dim) with a ReLU between
    them, followed by dropout, residual add and LayerNorm.

    Args:
        in_dim: channel size of the input and output.
        hidden_dim: channel size between the two convolutions.
        kernel_size: convolution kernel size; padding is (kernel_size - 1) // 2.
        dropout: dropout probability applied to the convolution output.
    """

    def __init__(self, in_dim, hidden_dim, kernel_size, dropout):
        super().__init__()
        pad = (kernel_size - 1) // 2
        expand = nn.Conv1d(in_dim, hidden_dim, kernel_size=kernel_size, padding=pad)
        contract = nn.Conv1d(hidden_dim, in_dim, kernel_size=kernel_size, padding=pad)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)
        self.layer_norm = LayerNorm(in_dim)
        # The same dropout module is exposed under both attribute names.
        drop = FairseqDropout(p=dropout, module_name=self.__class__.__name__)
        self.dropout = drop
        self.dropout_module = drop

    def forward(self, x):
        """Apply the block to ``x`` laid out as B x T x C; returns the same layout."""
        shortcut = x
        # Conv1d works on B x C x T, so swap the last two axes around the convs.
        hidden = self.ffn(x.transpose(1, 2))
        hidden = hidden.transpose(1, 2)
        return self.layer_norm(self.dropout(hidden) + shortcut)


class FFTLayer(torch.nn.Module):
    """Feed-forward Transformer layer: self-attention + residual + LayerNorm, then a conv FFN.

    Args:
        embed_dim: model width.
        n_heads: number of attention heads.
        hidden_dim: hidden channel size of the convolutional feed-forward block.
        kernel_size: kernel size of the feed-forward convolutions.
        dropout: dropout used inside the feed-forward block.
        attention_dropout: dropout on attention weights.
    """

    def __init__(
        self, embed_dim, n_heads, hidden_dim, kernel_size, dropout, attention_dropout
    ):
        super().__init__()
        self.self_attn = MultiheadAttention(
            embed_dim, n_heads, dropout=attention_dropout, self_attention=True
        )
        self.layer_norm = LayerNorm(embed_dim)
        self.ffn = PositionwiseFeedForward(
            embed_dim, hidden_dim, kernel_size, dropout=dropout
        )

    def forward(self, x, padding_mask=None):
        """Run the layer on ``x`` (B x T x C); ``padding_mask`` is passed as key_padding_mask."""
        shortcut = x
        # The attention module is fed time-major input (T x B x C).
        time_major = x.transpose(0, 1)
        attended, _ = self.self_attn(
            query=time_major,
            key=time_major,
            value=time_major,
            key_padding_mask=padding_mask,
            need_weights=False,
        )
        normed = self.layer_norm(attended.transpose(0, 1) + shortcut)
        return self.ffn(normed)


class LengthRegulator(nn.Module):
    """Expand each input position by its integer duration (FastSpeech length regulator)."""

    def forward(self, x, durations):
        """Repeat every time step of ``x`` ``durations[b, t]`` times.

        Args:
            x: B x T x C tensor.
            durations: B x T tensor of non-negative integer repeat counts.

        Returns:
            A tuple ``(out, out_lens)``. ``out`` is B x max(out_lens) x C, with
            rows shorter than the maximum zero-padded at the end. ``out_lens``
            is ``durations.sum(dim=1)``.
        """
        out_lens = durations.sum(dim=1)
        bsz, _, dim = x.size()
        # Guard the empty batch: max() of an empty tensor raises.
        max_len = int(out_lens.max()) if out_lens.numel() > 0 else 0
        out = x.new_zeros((bsz, max_len, dim))

        # repeat_interleave expands a whole row in one call instead of building
        # the index list token by token. The old per-element .item() forced a
        # host/device sync per token on GPU.
        repeats = durations.long()
        for b in range(bsz):
            expanded = torch.repeat_interleave(x[b], repeats[b], dim=0)
            out[b, : expanded.size(0)] = expanded

        return out, out_lens


class VariancePredictor(nn.Module):
    """Predict one scalar per input position.

    The same architecture is instantiated for duration, pitch and energy in
    this file. It has two Conv1d+ReLU layers, each followed by LayerNorm and
    dropout, then a linear projection to a single value.

    Args:
        args: namespace providing ``encoder_embed_dim``, ``var_pred_hidden_dim``,
            ``var_pred_kernel_size`` and ``var_pred_dropout``.
    """

    def __init__(self, args):
        super().__init__()
        kernel_size = args.var_pred_kernel_size
        # Both convolutions must use the same "same" padding so the time axis
        # keeps length T and the B x T output lines up with the padding mask.
        # A hard-coded padding=1 only preserves T when kernel_size == 3.
        # (k - 1) // 2 preserves T for any odd kernel size.
        same_padding = (kernel_size - 1) // 2
        self.conv1 = nn.Sequential(
            nn.Conv1d(
                args.encoder_embed_dim,
                args.var_pred_hidden_dim,
                kernel_size=kernel_size,
                padding=same_padding,
            ),
            nn.ReLU(),
        )
        self.ln1 = nn.LayerNorm(args.var_pred_hidden_dim)
        self.dropout_module = FairseqDropout(
            p=args.var_pred_dropout, module_name=self.__class__.__name__
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(
                args.var_pred_hidden_dim,
                args.var_pred_hidden_dim,
                kernel_size=kernel_size,
                padding=same_padding,
            ),
            nn.ReLU(),
        )
        self.ln2 = nn.LayerNorm(args.var_pred_hidden_dim)
        self.proj = nn.Linear(args.var_pred_hidden_dim, 1)

    def forward(self, x):
        """Map ``x`` (B x T x C) to per-position predictions (B x T)."""
        # Conv1d expects B x C x T; transpose around each convolution.
        x = self.conv1(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout_module(self.ln1(x))
        x = self.conv2(x.transpose(1, 2)).transpose(1, 2)
        x = self.dropout_module(self.ln2(x))
        return self.proj(x).squeeze(dim=2)


class VarianceAdaptor(nn.Module):
    """Predict duration, pitch and energy, inject the latter two as embeddings, then length-regulate.

    Pitch and energy values are quantised with ``torch.bucketize`` against
    ``var_pred_n_bins - 1`` evenly spaced boundaries. The ranges are
    [pitch_min, pitch_max] and [energy_min, energy_max] respectively. Each
    bucket index selects a learned embedding of size ``encoder_embed_dim``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.length_regulator = LengthRegulator()
        self.duration_predictor = VariancePredictor(args)
        self.pitch_predictor = VariancePredictor(args)
        self.energy_predictor = VariancePredictor(args)

        n_bins = self.args.var_pred_n_bins
        # n_bins - 1 boundaries give bucket indices 0 .. n_bins - 1.
        n_boundaries = n_bins - 1
        self.pitch_bins = torch.linspace(args.pitch_min, args.pitch_max, n_boundaries)
        self.embed_pitch = Embedding(n_bins, args.encoder_embed_dim)
        self.energy_bins = torch.linspace(
            args.energy_min, args.energy_max, n_boundaries
        )
        self.embed_energy = Embedding(n_bins, args.encoder_embed_dim)

    @staticmethod
    def _predict_and_embed(predictor, boundaries, embedding, x, tgt, factor):
        """Run ``predictor`` on ``x`` and embed either the target or the scaled prediction.

        Returns ``(prediction, embedding)``. With a target, the returned
        prediction is unscaled and the embedding uses the target. Without one,
        the prediction is multiplied by ``factor`` and the embedding uses it.
        """
        prediction = predictor(x)
        boundaries = boundaries.to(x.device)
        if tgt is not None:
            return prediction, embedding(torch.bucketize(tgt, boundaries))
        prediction = prediction * factor
        return prediction, embedding(torch.bucketize(prediction, boundaries))

    def get_pitch_emb(self, x, tgt=None, factor=1.0):
        """Return ``(pitch prediction, pitch embedding)`` for ``x`` (B x T x C)."""
        return self._predict_and_embed(
            self.pitch_predictor, self.pitch_bins, self.embed_pitch, x, tgt, factor
        )

    def get_energy_emb(self, x, tgt=None, factor=1.0):
        """Return ``(energy prediction, energy embedding)`` for ``x`` (B x T x C)."""
        return self._predict_and_embed(
            self.energy_predictor, self.energy_bins, self.embed_energy, x, tgt, factor
        )

    def forward(
        self,
        x,
        padding_mask,
        durations=None,
        pitches=None,
        energies=None,
        d_factor=1.0,
        p_factor=1.0,
        e_factor=1.0,
    ):
        """Adapt ``x`` (B x T x C) and expand it to frame level.

        Returns ``(x, out_lens, log_dur_out, pitch_out, energy_out)``.
        Ground-truth ``durations`` are used for length regulation when given.
        Otherwise the predicted durations are used: ``exp(log_dur) - 1``
        scaled by ``d_factor``, rounded, clamped at 0 and zeroed on padding.
        """
        log_dur_out = self.duration_predictor(x)
        scaled_dur = (torch.exp(log_dur_out) - 1) * d_factor
        dur_out = torch.clamp(torch.round(scaled_dur).long(), min=0)
        dur_out.masked_fill_(padding_mask, 0)

        pitch_out, pitch_emb = self.get_pitch_emb(x, pitches, p_factor)
        x = x + pitch_emb
        # Energy is predicted from the pitch-augmented representation.
        energy_out, energy_emb = self.get_energy_emb(x, energies, e_factor)
        x = x + energy_emb

        regulating_durations = dur_out if durations is None else durations
        x, out_lens = self.length_regulator(x, regulating_durations)
        return x, out_lens, log_dur_out, pitch_out, energy_out


class FastSpeech2Encoder(FairseqEncoder):
    """Full FastSpeech2 network packaged as a fairseq encoder.

    Pipeline: token embedding + scaled positional embedding -> encoder FFT
    layers -> optional speaker conditioning -> variance adaptor (frame-level
    expansion) -> scaled positional embedding -> decoder FFT layers -> linear
    projection to ``output_frame_dim * n_frames_per_step`` -> optional postnet
    residual.
    """

    def __init__(self, args, src_dict, embed_speaker):
        super().__init__(src_dict)
        self.args = args
        self.padding_idx = src_dict.pad()
        self.n_frames_per_step = args.n_frames_per_step
        self.out_dim = args.output_frame_dim * args.n_frames_per_step

        self.embed_speaker = embed_speaker
        # Projects [encoder output; speaker embedding] back to encoder width.
        self.spk_emb_proj = (
            None
            if embed_speaker is None
            else nn.Linear(
                args.encoder_embed_dim + args.speaker_embed_dim, args.encoder_embed_dim
            )
        )

        self.dropout_module = FairseqDropout(
            p=args.dropout, module_name=self.__class__.__name__
        )
        self.embed_tokens = Embedding(
            len(src_dict), args.encoder_embed_dim, padding_idx=self.padding_idx
        )

        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, args.encoder_embed_dim, self.padding_idx
        )
        # Learned scales for the positional embedding on each side.
        self.pos_emb_alpha = nn.Parameter(torch.ones(1))
        self.dec_pos_emb_alpha = nn.Parameter(torch.ones(1))

        self.encoder_fft_layers = self._build_fft_stack(
            args, args.encoder_embed_dim, args.encoder_attention_heads, args.encoder_layers
        )
        self.var_adaptor = VarianceAdaptor(args)
        self.decoder_fft_layers = self._build_fft_stack(
            args, args.decoder_embed_dim, args.decoder_attention_heads, args.decoder_layers
        )

        self.out_proj = nn.Linear(args.decoder_embed_dim, self.out_dim)

        self.postnet = (
            Postnet(
                self.out_dim,
                args.postnet_conv_dim,
                args.postnet_conv_kernel_size,
                args.postnet_layers,
                args.postnet_dropout,
            )
            if args.add_postnet
            else None
        )

        self.apply(model_init)

    @staticmethod
    def _build_fft_stack(args, embed_dim, n_heads, n_layers):
        """Create ``n_layers`` FFT layers sharing the FFT/dropout hyper-parameters in ``args``."""
        return nn.ModuleList(
            FFTLayer(
                embed_dim,
                n_heads,
                args.fft_hidden_dim,
                args.fft_kernel_size,
                dropout=args.dropout,
                attention_dropout=args.attention_dropout,
            )
            for _ in range(n_layers)
        )

    @staticmethod
    def _run_fft_stack(layers, x, padding_mask):
        """Apply each layer of ``layers`` to ``x`` in order with the same padding mask."""
        for fft_layer in layers:
            x = fft_layer(x, padding_mask)
        return x

    def forward(
        self,
        src_tokens,
        src_lengths=None,
        speaker=None,
        durations=None,
        pitches=None,
        energies=None,
        **kwargs
    ):
        """Return ``(x, x_post, out_lens, log_dur_out, pitch_out, energy_out)``.

        ``x_post`` is ``None`` when no postnet is configured. ``src_lengths``
        and extra keyword arguments are accepted but not used here.
        """
        x = self.embed_tokens(src_tokens)

        enc_padding_mask = src_tokens.eq(self.padding_idx)
        x += self.pos_emb_alpha * self.embed_positions(enc_padding_mask)
        x = self.dropout_module(x)
        x = self._run_fft_stack(self.encoder_fft_layers, x, enc_padding_mask)

        if self.embed_speaker is not None:
            bsz, seq_len, _ = x.size()
            spk = self.embed_speaker(speaker).expand(bsz, seq_len, -1)
            x = self.spk_emb_proj(torch.cat([x, spk], dim=2))

        adapted = self.var_adaptor(x, enc_padding_mask, durations, pitches, energies)
        x, out_lens, log_dur_out, pitch_out, energy_out = adapted

        dec_padding_mask = lengths_to_padding_mask(out_lens)
        x += self.dec_pos_emb_alpha * self.embed_positions(dec_padding_mask)
        x = self._run_fft_stack(self.decoder_fft_layers, x, dec_padding_mask)

        x = self.out_proj(x)
        x_post = None if self.postnet is None else x + self.postnet(x)
        return x, x_post, out_lens, log_dur_out, pitch_out, energy_out


@register_model("fastspeech2")
class FastSpeech2Model(FairseqEncoderModel):
    """
    Implementation for https://arxiv.org/abs/2006.04558
    """

    NON_AUTOREGRESSIVE = True

    @staticmethod
    def add_args(parser):
        """Register the model's command-line options on ``parser``."""
        # Typed options: general, FFT blocks, then variance predictor.
        typed_options = (
            ("--dropout", float),
            ("--output-frame-dim", int),
            ("--speaker-embed-dim", int),
            ("--fft-hidden-dim", int),
            ("--fft-kernel-size", int),
            ("--attention-dropout", float),
            ("--encoder-layers", int),
            ("--encoder-embed-dim", int),
            ("--encoder-attention-heads", int),
            ("--decoder-layers", int),
            ("--decoder-embed-dim", int),
            ("--decoder-attention-heads", int),
            ("--var-pred-n-bins", int),
            ("--var-pred-hidden-dim", int),
            ("--var-pred-kernel-size", int),
            ("--var-pred-dropout", float),
        )
        for flag, arg_type in typed_options:
            parser.add_argument(flag, type=arg_type)

        # postnet
        parser.add_argument("--add-postnet", action="store_true")
        postnet_options = (
            ("--postnet-dropout", float),
            ("--postnet-layers", int),
            ("--postnet-conv-dim", int),
            ("--postnet-conv-kernel-size", int),
        )
        for flag, arg_type in postnet_options:
            parser.add_argument(flag, type=arg_type)

    def __init__(self, encoder, args, src_dict):
        super().__init__(encoder)
        self._num_updates = 0

        # A CTC projection head over the source vocabulary is created only
        # when args.ctc_weight is set and positive.
        ctc_weight = getattr(args, "ctc_weight", 0.0)
        out_dim = args.output_frame_dim * args.n_frames_per_step
        self.ctc_proj = nn.Linear(out_dim, len(src_dict)) if ctc_weight > 0.0 else None

    @classmethod
    def build_model(cls, args, task):
        """Construct the encoder (with the task's speaker embeddings) and wrap it."""
        embed_speaker = task.get_speaker_embeddings(args)
        encoder = FastSpeech2Encoder(args, task.src_dict, embed_speaker)
        return cls(encoder, args, task.src_dict)

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)
        self._num_updates = num_updates

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """(Log-)softmax over ``ctc_proj`` applied to the first network output.

        Requires ``ctc_proj`` to exist, i.e. the model was built with a
        positive ``ctc_weight``.
        """
        logits = self.ctc_proj(net_output[0]).float()
        normalize = utils.log_softmax if log_probs else utils.softmax
        return normalize(logits, dim=-1)


@register_model_architecture("fastspeech2", "fastspeech2")
def base_architecture(args):
    """Fill every unset FastSpeech2 hyper-parameter on ``args`` with its default.

    Values already present on ``args`` are kept; missing ones are set in place.
    """
    defaults = (
        ("dropout", 0.2),
        ("output_frame_dim", 80),
        ("speaker_embed_dim", 64),
        # FFT blocks
        ("fft_hidden_dim", 1024),
        ("fft_kernel_size", 9),
        ("attention_dropout", 0.0),
        ("encoder_layers", 4),
        ("encoder_embed_dim", 256),
        ("encoder_attention_heads", 2),
        ("decoder_layers", 4),
        ("decoder_embed_dim", 256),
        ("decoder_attention_heads", 2),
        # variance predictor
        ("var_pred_n_bins", 256),
        ("var_pred_hidden_dim", 256),
        ("var_pred_kernel_size", 3),
        ("var_pred_dropout", 0.5),
        # postnet
        ("add_postnet", False),
        ("postnet_dropout", 0.5),
        ("postnet_layers", 5),
        ("postnet_conv_dim", 512),
        ("postnet_conv_kernel_size", 5),
    )
    for attr_name, default_value in defaults:
        setattr(args, attr_name, getattr(args, attr_name, default_value))
