from modules.operations import *

# Default position counts. The encoder/decoder below store these as
# `max_source_positions` / `max_target_positions` and use them (plus
# `padding_idx + 1`) as the `init_size` of their positional-embedding tables.
DEFAULT_MAX_SOURCE_POSITIONS = 2000
DEFAULT_MAX_TARGET_POSITIONS = 2000

class TransducerEncoder(nn.Module):
    """Text-side encoder of the transducer.

    Token embeddings are scaled by ``sqrt(embed_dim)``, summed with positional
    embeddings, passed through dropout and then through a stack of
    ``TransducerEncoderLayer`` modules chosen by ``arch``. When ``last_ln`` is
    set, the output is layer-normalised and padded positions are zeroed.
    """

    def __init__(self, arch, embed_tokens, last_ln=True):
        """
        :param arch: per-layer operation ids; ``arch[i]`` selects layer ``i``
        :param embed_tokens: token embedding module (provides ``padding_idx``
            and ``embedding_dim``)
        :param last_ln: whether to apply a final LayerNorm (and padding reset)
        """
        super().__init__()
        self.arch = arch
        # Layer count, width and dropout all come from the global `hparams`.
        self.num_layers = hparams['text_enc_layers']
        self.hidden_size = hparams['hidden_size']

        pad = embed_tokens.padding_idx
        embed_dim = embed_tokens.embedding_dim
        self.embed_tokens = embed_tokens
        self.padding_idx = pad
        self.dropout = hparams['dropout']
        self.embed_scale = math.sqrt(embed_dim)
        self.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS

        table_size = self.max_source_positions + pad + 1
        self.embed_positions = SinusoidalPositionalEmbedding(
            embed_dim, pad, init_size=table_size,
        )

        stack = [
            TransducerEncoderLayer(self.arch[idx], self.hidden_size, self.dropout)
            for idx in range(self.num_layers)
        ]
        self.layers = nn.ModuleList(stack)

        self.last_ln = last_ln
        if last_ln:
            self.layer_norm = LayerNorm(embed_dim)

    def forward_embedding(self, src_tokens):
        """Return ``(dropout(scaled_embedding + positions), scaled_embedding)``."""
        token_embed = self.embed_scale * self.embed_tokens(src_tokens)
        with_positions = token_embed + self.embed_positions(src_tokens)
        dropped = F.dropout(with_positions, p=self.dropout, training=self.training)
        return dropped, token_embed

    def forward(self, src_tokens):
        """
        Encode a batch of token ids.

        :param src_tokens: [B, T]
        :return: {
            'encoder_out': [T x B x C]
            'encoder_padding_mask': [B x T]
            'encoder_embedding': [B x T x C]
            'attn_w': []
        }
        """
        hidden, encoder_embedding = self.forward_embedding(src_tokens)

        # Layers operate time-major: B x T x C -> T x B x C.
        hidden = hidden.transpose(0, 1)

        # True wherever the token equals the padding index.
        pad_mask = src_tokens.eq(self.padding_idx).data

        for enc_layer in self.layers:
            hidden = enc_layer(hidden, encoder_padding_mask=pad_mask)

        if self.last_ln:
            # Normalise, then multiply padded time steps by zero.
            keep = (1 - pad_mask.float()).transpose(0, 1)[..., None]
            hidden = self.layer_norm(hidden) * keep

        result = {
            'encoder_out': hidden,  # T x B x C
            'encoder_padding_mask': pad_mask,  # B x T
            'encoder_embedding': encoder_embedding,  # B x T x C
            'attn_w': [],
        }
        return result


class TransducerDecoder(nn.Module):
    """Speech-side network of the transducer.

    Previous mel frames go through a three-layer prenet, receive positional
    embeddings, and are processed by a stack of ``TransducerDecoderLayer``
    modules followed by a LayerNorm. Supports full-sequence and incremental
    (one step at a time) evaluation.
    """

    def __init__(self, arch, padding_idx=0, num_layers=None, causal=True, dropout=None, out_dim=None):
        """
        :param arch: per-layer operation ids; ``arch[i]`` selects layer ``i``
        :param padding_idx: padding index handed to the positional embedding
        :param num_layers: layer count; defaults to ``hparams['speech_enc_layers']``
        :param causal: build a future mask for full-sequence evaluation
        :param dropout: dropout rate; defaults to ``hparams['dropout']``
        :param out_dim: stored as ``self.out_dim``; defaults to mel bins + 1
        """
        super().__init__()
        self.arch = arch
        if num_layers is None:
            num_layers = hparams['speech_enc_layers']
        self.num_layers = num_layers
        self.hidden_size = hparams['hidden_size']
        self.prenet_hidden_size = hparams['prenet_hidden_size']
        self.padding_idx = padding_idx
        self.causal = causal
        if dropout is None:
            dropout = hparams['dropout']
        self.dropout = dropout
        self.in_dim = hparams['audio_num_mel_bins']
        if out_dim is None:
            out_dim = hparams['audio_num_mel_bins'] + 1
        self.out_dim = out_dim
        self.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS

        table_size = self.max_target_positions + self.padding_idx + 1
        self.embed_positions = SinusoidalPositionalEmbedding(
            self.hidden_size, self.padding_idx, init_size=table_size,
        )

        stack = [
            TransducerDecoderLayer(self.arch[idx], self.hidden_size, self.dropout)
            for idx in range(self.num_layers)
        ]
        self.layers = nn.ModuleList(stack)
        self.layer_norm = LayerNorm(self.hidden_size)

        # Prenet: in_dim -> prenet_hidden -> prenet_hidden -> hidden_size.
        self.prenet_fc1 = Linear(self.in_dim, self.prenet_hidden_size)
        self.prenet_fc2 = Linear(self.prenet_hidden_size, self.prenet_hidden_size)
        self.prenet_fc3 = Linear(self.prenet_hidden_size, self.hidden_size, bias=False)

    def forward_prenet(self, x):
        """Project mel frames to ``hidden_size``; all-zero input frames stay zero.

        Dropout (p=0.5) after the first two layers is applied with
        ``training=True``, i.e. it is active in eval mode as well.
        """
        # 1.0 for frames with any non-zero entry, 0.0 for all-zero frames.
        nonzero_frames = x.abs().sum(-1, keepdim=True).ne(0).float()

        prenet_dropout = 0.5
        for fc in (self.prenet_fc1, self.prenet_fc2):
            x = F.dropout(F.relu(fc(x)), prenet_dropout, training=True)
        x = F.relu(self.prenet_fc3(x))
        return x * nonzero_frames

    def forward(
            self,
            prev_output_mels,  # B x T x 80
            target_mels=None,
            incremental_state=None,
    ):
        """Run the network on previous mel frames.

        :param prev_output_mels: previous mel frames, B x T x n_mels
        :param target_mels: accepted but not used by this method
        :param incremental_state: when given, only the last frame is processed
            and the state is forwarded to the positional embedding and layers
        :return: B x T x hidden_size (T == 1 in incremental mode)
        """
        if incremental_state is None:
            # Work on a copy whose first frame is shifted by +1, so that frame
            # has non-zero energy for position/padding purposes
            # (presumably an all-zero start frame -- confirm with the caller).
            shifted = prev_output_mels.clone()
            shifted[:, 0] = shifted[:, 0] + 1.
            frame_energy = shifted.abs().sum(-1)
            positions = self.embed_positions(
                frame_energy,
                incremental_state=incremental_state
            )
            # Frames with zero energy are treated as padding.
            self_attn_padding_mask = frame_energy.eq(0).data
        else:
            positions = self.embed_positions(
                prev_output_mels.abs().sum(-1),
                incremental_state=incremental_state
            )
            # Only the newest frame is fed through the network.
            prev_output_mels = prev_output_mels[:, -1:, :]
            positions = positions[:, -1:, :]
            self_attn_padding_mask = None

        x = self.forward_prenet(prev_output_mels)
        x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # The future mask is only needed for full-sequence causal evaluation.
        needs_future_mask = incremental_state is None and self.causal
        for dec_layer in self.layers:
            self_attn_mask = self.buffered_future_mask(x) if needs_future_mask else None
            x = dec_layer(
                x,
                incremental_state=incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask
            )

        # Normalise and go back to batch-major: T x B x C -> B x T x C.
        return self.layer_norm(x).transpose(0, 1)

    def buffered_future_mask(self, tensor):
        """Return a cached ``dim x dim`` strictly-upper-triangular -inf mask.

        ``dim`` is ``tensor.size(0)``. The cache is rebuilt when it is missing,
        lives on another device, or is smaller than requested.
        """
        size = tensor.size(0)
        cached = getattr(self, '_future_mask', None)
        if cached is None or cached.device != tensor.device or cached.size(0) < size:
            cached = torch.triu(utils.fill_with_neg_inf(tensor.new(size, size)), 1)
            self._future_mask = cached
        return cached[:size, :size]


class JointNetwork(nn.Module):
    """Transducer joint network (concatenation variant, "Joint_v1").

    Both input streams are linearly projected, paired over every (t, u)
    position, concatenated on the feature axis and fused by a two-layer MLP.
    The result goes through ``n_layers`` ``JointNetFFNLayer`` blocks, a
    LayerNorm and a final projection to ``out_dim``.
    """

    def __init__(self, out_dim, hidden_size, n_layers, dropout, activation_type='ReLU'):
        """
        Args:
            out_dim (int): size of the last dimension of the output
            hidden_size (int): feature size H of both input streams
            n_layers (int): number of ``JointNetFFNLayer`` blocks
            dropout (float): dropout rate passed to each FFN block
            activation_type (str): name of an activation class in ``torch.nn``
        """
        super().__init__()
        self.text_enc_linear = Linear(hidden_size, hidden_size)
        self.speech_enc_linear = Linear(hidden_size, hidden_size)
        self.joint_linear1 = Linear(2 * hidden_size, 4 * hidden_size)
        self.joint_activation = getattr(torch.nn, activation_type)()
        self.joint_linear2 = Linear(4 * hidden_size, hidden_size)

        self.layers = nn.ModuleList([
            JointNetFFNLayer(hidden_size, 4 * hidden_size, dropout, activation_type)
            for _ in range(n_layers)
        ])
        self.layer_norm = LayerNorm(hidden_size)
        self.out_linear = Linear(hidden_size, out_dim)

    def forward(self, text_encoder_outputs, speech_encoder_outputs):
        """
        Args:
            text_encoder_outputs (tensor): (B, T, H)
            speech_encoder_outputs (tensor): (B, U, H)
        Returns:
            (tensor): (B, T, U, out_dim)
        """
        # Unpacking also enforces that both inputs are 3-dimensional.
        _, T, _ = text_encoder_outputs.size()
        _, U, _ = speech_encoder_outputs.size()

        # Broadcast each stream over the other's time axis. `expand` returns
        # views, so the only (B, T, U, 2H) tensor materialised is the
        # concatenation itself (the previous `repeat` copied both operands).
        text_proj = self.text_enc_linear(text_encoder_outputs)
        text_proj = text_proj.unsqueeze(2).expand(-1, -1, U, -1)
        speech_proj = self.speech_enc_linear(speech_encoder_outputs)
        speech_proj = speech_proj.unsqueeze(1).expand(-1, T, -1, -1)

        fused = self.joint_linear1(torch.cat((text_proj, speech_proj), dim=-1))
        x = self.joint_linear2(self.joint_activation(fused))

        for layer in self.layers:
            x = layer(x)
        x = self.layer_norm(x)
        return self.out_linear(x)


class SpeechTransducer(nn.Module):
    """Full model: text encoder + speech-side network + joint network.

    ``arch`` lists one operation id per layer; the first
    ``hparams['text_enc_layers']`` entries configure the text encoder and the
    remainder configure the speech-side network.
    """

    def __init__(self, arch, dictionary):
        """
        :param arch: whitespace-separated string of ints, or a list/tuple
        :param dictionary: vocabulary object providing ``len()`` and ``pad()``
        """
        super().__init__()
        if not isinstance(arch, str):
            assert isinstance(arch, (list, tuple))
            self.arch = arch
        else:
            self.arch = [int(token) for token in arch.strip().split()]

        # Split the layer spec between the two sub-networks.
        self.text_enc_arch = self.arch[:hparams['text_enc_layers']]
        self.speech_enc_arch = self.arch[hparams['text_enc_layers']:]

        self.dictionary = dictionary
        self.vocab = len(dictionary)
        self.padding_idx = dictionary.pad()
        self.hidden_size = hparams['hidden_size']
        self.n_mels = hparams['audio_num_mel_bins']
        self.tts_joint_params = hparams['tts_joint_network_params']

        # Sub-modules, built in dependency order.
        self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
        self.text_encoder = self.build_text_encoder()
        self.speech_encoder = self.build_speech_encoder()
        self.tts_joint_network = self.build_tts_joint_network()

    def build_embedding(self, dictionary, embed_dim):
        """Create the token embedding table, one row per dictionary entry."""
        return Embedding(len(dictionary), embed_dim, self.padding_idx)

    def build_text_encoder(self):
        """Create the text encoder over the shared token embedding."""
        return TransducerEncoder(self.text_enc_arch, self.encoder_embed_tokens)

    def build_speech_encoder(self):
        """Create the causal speech-side network."""
        return TransducerDecoder(
            self.speech_enc_arch,
            padding_idx=self.padding_idx,
            causal=True,
        )

    def build_tts_joint_network(self):
        """Create the joint network; its output has ``n_mels + 1`` channels."""
        return JointNetwork(out_dim=self.n_mels + 1, **self.tts_joint_params)

    def forward_text_encoder(self, src_tokens, *args, **kwargs):
        """Run the text encoder; extra positional/keyword arguments are ignored."""
        return self.text_encoder(src_tokens)

    def forward_speech_encoder(self, prev_output_mels, incremental_state=None):
        """Run the speech-side network, optionally in incremental mode."""
        return self.speech_encoder(
            prev_output_mels, incremental_state=incremental_state)

    def forward(self, src_tokens, prev_output_mels, target_mels, *args, **kwargs):
        """Return the joint-network output for a (text, previous mels) batch."""
        encoded = self.forward_text_encoder(src_tokens)
        # The encoder output's first two axes are swapped before the joint network.
        text_hidden = encoded['encoder_out'].transpose(0, 1)
        speech_hidden = self.speech_encoder(prev_output_mels, target_mels=target_mels)
        return self.tts_joint_network(text_hidden, speech_hidden)


def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Build an ``nn.Embedding`` with N(0, embedding_dim ** -0.5) initialised weights.

    :param num_embeddings: number of rows in the table
    :param embedding_dim: size of each embedding vector
    :param padding_idx: row to keep at zero, or ``None`` for no padding row
    :return: the initialised ``nn.Embedding`` module
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    # `normal_` overwrites the padding row, so reset it to zero. Skip this when
    # there is no padding row: `m.weight[None]` is a view of the WHOLE table
    # and zeroing it would wipe every embedding.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m

class TransducerEncoderLayer(nn.Module):
    """Thin wrapper around one encoder operation looked up in ``OPERATIONS_ENCODER``."""

    def __init__(self, layer, hidden_size, dropout):
        """
        :param layer: key into ``OPERATIONS_ENCODER`` selecting the operation
        :param hidden_size: feature size passed to the operation
        :param dropout: dropout rate passed to the operation
        """
        super().__init__()
        self.layer = layer
        self.hidden_size = hidden_size
        self.dropout = dropout

        make_op = OPERATIONS_ENCODER[layer]
        # Operation 13 additionally receives the `gaus_bias` / `gaus_tao` hparams.
        extra_args = (hparams['gaus_bias'], hparams['gaus_tao']) if layer == 13 else ()
        self.op = make_op(hidden_size, dropout, *extra_args)

    def forward(self, x, **kwargs):
        """Delegate to the wrapped operation, forwarding all keyword arguments."""
        return self.op(x, **kwargs)

class TransducerDecoderLayer(nn.Module):
    """Thin wrapper around one decoder operation looked up in ``OPERATIONS_DECODER``."""

    def __init__(self, layer, hidden_size, dropout):
        """
        :param layer: key into ``OPERATIONS_DECODER`` selecting the operation
        :param hidden_size: feature size passed to the operation
        :param dropout: dropout rate passed to the operation
        """
        super().__init__()
        self.layer, self.hidden_size, self.dropout = layer, hidden_size, dropout
        make_op = OPERATIONS_DECODER[layer]
        self.op = make_op(hidden_size, dropout)

    def forward(self, x, **kwargs):
        """Delegate to the wrapped operation, forwarding all keyword arguments."""
        wrapped = self.op
        return wrapped(x, **kwargs)

    def clear_buffer(self, *args):
        """Forward ``clear_buffer`` to the wrapped operation."""
        wrapped = self.op
        return wrapped.clear_buffer(*args)

    def set_buffer(self, *args):
        """Forward ``set_buffer`` to the wrapped operation."""
        wrapped = self.op
        return wrapped.set_buffer(*args)

class JointNetFFNLayer(nn.Module):
    """Pre-LayerNorm feed-forward block with a residual connection.

    Computes ``x + dropout(ffn2(dropout(act(ffn1(layer_norm(x))))))``.
    """

    def __init__(self, hidden_size, filter_size, dropout, activation_type='ReLU'):
        """
        :param hidden_size: input/output feature size
        :param filter_size: inner feature size of the feed-forward block
        :param dropout: dropout rate used after the activation and after ``ffn2``
        :param activation_type: name of an activation class in ``torch.nn``
        """
        super().__init__()
        self.dropout = dropout
        self.layer_norm = LayerNorm(hidden_size)
        self.ffn1 = Linear(hidden_size, filter_size)
        activation_cls = getattr(torch.nn, activation_type)
        self.activation = activation_cls()
        self.ffn2 = Linear(filter_size, hidden_size)

    def forward(self, x):
        """Apply the block to ``x`` (B x T x U x H) and return the same shape."""
        rate, is_training = self.dropout, self.training

        inner = self.activation(self.ffn1(self.layer_norm(x)))
        inner = F.dropout(inner, rate, training=is_training)

        projected = F.dropout(self.ffn2(inner), rate, training=is_training)
        return x + projected
