from model.subblocks import FCBlock, Conv1DBlock   
from utils.tools import get_sinusoid_encoding_table
import torch
import torch.nn as nn 
from .blocks import FFTBlock, FFTBlock2 
from .subblocks import Mish
from utils.tools import sequence_mask

class PhonemePreNet(nn.Module):
    """Residual phoneme prenet.

    Two ``Conv1DBlock`` layers (Mish activation) followed by an ``FCBlock``,
    all of width ``encoder.encoder_hidden``; the result is added back onto
    the input.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config["encoder"]["encoder_hidden"]
        prenet_cfg = config["prenet"]
        kernel = prenet_cfg["conv_kernel_size"]
        p_drop = prenet_cfg["dropout"]

        # Each conv block gets its own Mish() instance.
        conv_blocks = [
            Conv1DBlock(hidden, hidden, kernel, activation=Mish(), dropout=p_drop)
            for _ in range(2)
        ]
        self.prenet_layer = nn.Sequential(
            *conv_blocks,
            FCBlock(hidden, hidden, dropout=p_drop),
        )

    def forward(self, x, mask=None):
        """Return ``x`` plus the (optionally masked) prenet transform of ``x``.

        Args:
            x: input features; assumed (batch, time, hidden) — TODO confirm.
            mask: optional boolean mask broadcast over the last axis; positions
                where it is True have the prenet branch zeroed, so only the
                residual input survives there.
        """
        branch = self.prenet_layer(x)
        if mask is not None:
            branch = branch.masked_fill(mask.unsqueeze(-1), 0)
        return x + branch

class Encoder(nn.Module):
    """Encoder: sinusoidal positional encoding followed by stacked FFTBlocks.

    Config keys read: ``max_seq_len`` and, under ``encoder``,
    ``encoder_hidden``, ``encoder_layer``, ``encoder_head``,
    ``conv_filter_size``, ``conv_kernel_size``, ``dropout``.
    """

    def __init__(self, config):
        super(Encoder, self).__init__()

        enc_cfg = config["encoder"]
        n_position = config["max_seq_len"] + 1
        d_word_vec = enc_cfg["encoder_hidden"]
        n_layers = enc_cfg["encoder_layer"]
        n_head = enc_cfg["encoder_head"]
        d_k = d_v = enc_cfg["encoder_hidden"] // enc_cfg["encoder_head"]
        d_model = enc_cfg["encoder_hidden"]
        d_inner = enc_cfg["conv_filter_size"]
        kernel_size = enc_cfg["conv_kernel_size"]
        dropout = enc_cfg["dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        # Precomputed table for up to max_seq_len + 1 positions. Kept as a
        # frozen Parameter so it is saved in the state_dict and converted by
        # module-level .to()/.half() together with the rest of the model.
        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
            requires_grad=False,
        )

        self.layer_stack = nn.ModuleList(
            [
                FFTBlock(
                    d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, src_seq, mask, return_attns=False):
        """Add positional encoding to ``src_seq`` and run the FFTBlock stack.

        Args:
            src_seq: input features indexed as (batch, time, ...) whose last
                axis is presumably ``d_model`` wide — TODO confirm with caller.
            mask: per-position mask, assumed (batch, time); presumably True at
                padded positions (as PhonemePreNet treats its mask) — confirm
                against FFTBlock.
            return_attns: accepted for backward compatibility only; attention
                maps are never returned.

        Returns:
            The output of the last FFTBlock.

        In eval mode, inputs longer than ``max_seq_len`` get a freshly built
        positional table. Otherwise the stored table is sliced, so in training
        mode an input longer than that table fails to broadcast.
        """
        max_len = src_seq.shape[1]

        # The same key mask is shared by every query row: (B, T) -> (B, T, T).
        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)

        if not self.training and max_len > self.max_seq_len:
            position = get_sinusoid_encoding_table(max_len, self.d_model)[:max_len, :]
            # Match dtype as well as device: the stored position_enc follows
            # model.half(), but a freshly built table does not, and adding it
            # un-cast would promote a half-precision input before the layers.
            position = position.unsqueeze(0).to(
                device=src_seq.device, dtype=src_seq.dtype
            )
        else:
            position = self.position_enc[:, :max_len, :]

        # (1, T, D) broadcasts over the batch axis.
        enc_output = src_seq + position

        for enc_layer in self.layer_stack:
            enc_output, _ = enc_layer(
                enc_output, mask=mask, slf_attn_mask=slf_attn_mask
            )

        return enc_output


class TextEncoder(nn.Module):
    """Text encoder wrapping a single ``FFTBlock2`` stack.

    Hyper-parameters come from ``config["encoder"]``; the wrapped block runs
    in channels-first layout and the result is converted back on return.
    """

    def __init__(self, config):
        super().__init__()
        enc_cfg = config["encoder"]
        self.hidden_channels = enc_cfg["encoder_hidden"]
        self.filter_channels = enc_cfg["conv_filter_size"]
        self.n_heads = enc_cfg["encoder_head"]
        self.n_layers = enc_cfg["encoder_layer"]
        self.kernel_size = enc_cfg["conv_kernel_size"]
        self.p_dropout = enc_cfg["dropout"]

        block_args = (
            self.hidden_channels,
            self.filter_channels,
            self.n_heads,
            self.n_layers,
            self.kernel_size,
            self.p_dropout,
        )
        self.encoder = FFTBlock2(*block_args)

    def forward(self, x, x_lengths):
        """Encode ``x`` and return it in its original layout.

        Args:
            x: input features laid out [b, t, h].
            x_lengths: presumably per-sequence valid lengths, as consumed by
                ``sequence_mask`` — confirm against utils.tools.

        Returns:
            Encoded features laid out [b, t, h].
        """
        channels_first = x.transpose(1, -1)  # [b, h, t]
        n_frames = channels_first.size(2)
        # Float mask with a singleton axis at dim 1 so it broadcasts over h.
        frame_mask = sequence_mask(x_lengths, n_frames).unsqueeze(1).to(channels_first.dtype)

        encoded = self.encoder(channels_first * frame_mask, frame_mask)
        return encoded.transpose(1, -1)  # [b, t, h]

class Decoder(nn.Module):
    """Mel decoder: sinusoidal positional encoding followed by stacked FFTBlocks.

    Config keys read: ``max_seq_len`` and, under ``decoder``,
    ``decoder_hidden``, ``decoder_layer``, ``decoder_head``,
    ``conv_filter_size``, ``conv_kernel_size``, ``dropout``.
    """

    def __init__(self, config):
        super(Decoder, self).__init__()

        dec_cfg = config["decoder"]
        n_position = config["max_seq_len"] + 1
        d_word_vec = dec_cfg["decoder_hidden"]
        n_layers = dec_cfg["decoder_layer"]
        n_head = dec_cfg["decoder_head"]
        d_k = d_v = dec_cfg["decoder_hidden"] // dec_cfg["decoder_head"]
        d_model = dec_cfg["decoder_hidden"]
        d_inner = dec_cfg["conv_filter_size"]
        kernel_size = dec_cfg["conv_kernel_size"]
        dropout = dec_cfg["dropout"]

        self.max_seq_len = config["max_seq_len"]
        self.d_model = d_model

        # Frozen Parameter: stored in the state_dict and converted by
        # module-level .to()/.half() together with the rest of the model.
        self.position_enc = nn.Parameter(
            get_sinusoid_encoding_table(n_position, d_word_vec).unsqueeze(0),
            requires_grad=False,
        )

        self.layer_stack = nn.ModuleList(
            [
                FFTBlock(
                    d_model, n_head, d_k, d_v, d_inner, kernel_size, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, enc_seq, mask, return_attns=False):
        """Add positional encoding to ``enc_seq`` and run the FFTBlock stack.

        Args:
            enc_seq: input features indexed as (batch, time, ...) whose last
                axis is presumably ``d_model`` wide — TODO confirm with caller.
            mask: per-position mask, assumed (batch, time); presumably True at
                padded positions — confirm against FFTBlock.
            return_attns: accepted for backward compatibility only; attention
                maps are never returned.

        Returns:
            ``(dec_output, mask)``. Except in the eval-mode long-input case,
            both are clipped to at most ``max_seq_len`` time steps, so callers
            must use the returned mask rather than the one passed in.
        """
        max_len = enc_seq.shape[1]

        if not self.training and max_len > self.max_seq_len:
            # Eval on an over-long input: build a table covering every frame.
            position = get_sinusoid_encoding_table(max_len, self.d_model)[:max_len, :]
            # Match dtype as well as device so a freshly built table cannot
            # promote a half-precision input (the stored table follows .half()).
            position = position.unsqueeze(0).to(
                device=enc_seq.device, dtype=enc_seq.dtype
            )
            dec_output = enc_seq + position
        else:
            # Clip to the stored table's range; a no-op when the input fits.
            max_len = min(max_len, self.max_seq_len)
            dec_output = enc_seq[:, :max_len, :] + self.position_enc[:, :max_len, :]
            mask = mask[:, :max_len]

        # Built from the (possibly clipped) mask: (B, T) -> (B, T, T).
        slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)

        for dec_layer in self.layer_stack:
            dec_output, _ = dec_layer(
                dec_output, mask=mask, slf_attn_mask=slf_attn_mask
            )

        return dec_output, mask