import json
import math
import os

import torch
import torch.nn as nn

from text.symbols import symbols
from utils.tools import get_mask_from_lengths, sequence_mask
from model.variance_adaptor import VarianceAdaptor
from model.modules import PhonemePreNet, Encoder, Decoder, TextEncoder

class DPPTTS(nn.Module):
    """DPPTTS backbone model.

    ``forward`` chains the sub-modules in this order: phoneme embedding ->
    ``PhonemePreNet`` -> ``TextEncoder`` -> ``VarianceAdaptor`` -> ``Decoder``
    -> linear projection to ``n_mel_channels``.
    """

    def __init__(self, preprocess_config, model_config):
        """Build the sub-modules from the two configuration mappings.

        Args:
            preprocess_config: Mapping read for
                ``["preprocessing"]["mel"]["n_mel_channels"]`` and, when
                multi-speaker mode is enabled,
                ``["path"]["preprocessed_path"]``. It is also forwarded to
                ``VarianceAdaptor``.
            model_config: Mapping read for ``["encoder"]["encoder_hidden"]``,
                ``["decoder"]["decoder_hidden"]`` and ``["multi_speaker"]``.
                It is also forwarded to every sub-module constructor.

        Raises:
            OSError: If multi-speaker mode is enabled and ``speakers.json``
                cannot be opened.
        """
        super().__init__()

        # One slot more than the symbol table; presumably reserved for a
        # padding/unknown index -- TODO confirm against the text frontend.
        self.n_src_vocab = len(symbols) + 1
        self.word_dim = model_config["encoder"]["encoder_hidden"]

        # NOTE: sub-modules are created in a fixed order so that the sequence
        # of random parameter initialisations stays reproducible.
        self.embedding = nn.Embedding(self.n_src_vocab, self.word_dim)
        self.prenet = PhonemePreNet(model_config)
        self.encoder = TextEncoder(model_config)
        self.variance_adaptor = VarianceAdaptor(preprocess_config, model_config)
        self.decoder = Decoder(model_config)
        self.mel_linear = nn.Linear(
            model_config["decoder"]["decoder_hidden"],
            preprocess_config["preprocessing"]["mel"]["n_mel_channels"],
        )

        self.speaker_emb = None
        if model_config["multi_speaker"]:
            n_speaker = self._count_speakers(preprocess_config)
            # NOTE(review): every other size in this class is read from
            # model_config["encoder"] / ["decoder"], but this one is read from
            # model_config["transformer"]. Confirm that the config actually
            # defines a "transformer" section, otherwise this raises KeyError.
            self.speaker_emb = nn.Embedding(
                n_speaker,
                model_config["transformer"]["encoder_hidden"],
            )

        # Std of word_dim ** -0.5 pairs with the sqrt(word_dim) scaling that
        # is applied to the embedding output in encode_text().
        nn.init.normal_(self.embedding.weight, 0.0, self.word_dim**-0.5)

    @staticmethod
    def _count_speakers(preprocess_config):
        """Return the number of entries in the preprocessed ``speakers.json``.

        The file is looked up inside
        ``preprocess_config["path"]["preprocessed_path"]``.
        """
        speakers_path = os.path.join(
            preprocess_config["path"]["preprocessed_path"], "speakers.json"
        )
        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(speakers_path, "r", encoding="utf-8") as f:
            return len(json.load(f))

    def forward(
        self,
        speakers,
        texts,
        src_lens,
        max_src_len,
        mels=None,
        mel_lens=None,
        max_mel_len=None,
        p_targets=None,
        e_targets=None,
        d_targets=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Run the full text-to-mel pipeline.

        Args:
            speakers: Accepted for interface compatibility; not used by this
                method (``self.speaker_emb`` is never applied here).
            texts: Integer index tensor fed to the phoneme embedding.
            src_lens: Source lengths, used to build the source mask and
                passed to the text encoder.
            max_src_len: Maximum source length for the source mask.
            mels: Accepted for interface compatibility; not used here.
            mel_lens: Optional mel lengths. When ``None`` no mel mask is
                built and ``None`` is handed to the variance adaptor.
            max_mel_len: Maximum mel length for the mel mask; also forwarded
                to the variance adaptor.
            p_targets: Forwarded unchanged to the variance adaptor.
            e_targets: Forwarded unchanged to the variance adaptor.
            d_targets: Forwarded unchanged to the variance adaptor.
            p_control: Forwarded unchanged to the variance adaptor.
            e_control: Forwarded unchanged to the variance adaptor.
            d_control: Forwarded unchanged to the variance adaptor.

        Returns:
            A 9-tuple ``(output, p_predictions, e_predictions, d_predictions,
            d_rounded, src_masks, mel_masks, src_lens, mel_lens)`` where
            ``output`` is the decoder output after the mel projection and
            ``mel_lens`` / ``mel_masks`` are the values returned by the
            variance adaptor / decoder, not necessarily the ones passed in.
        """
        # Shared with encode_text() so both entry points encode identically.
        x, src_masks = self.encode_text(texts, src_lens, max_src_len)

        mel_masks = None
        if mel_lens is not None:
            mel_masks = get_mask_from_lengths(mel_lens, max_mel_len)

        (
            output,
            p_predictions,
            e_predictions,
            d_predictions,
            d_rounded,
            mel_lens,
            mel_masks,
        ) = self.variance_adaptor(
            x,
            src_masks,
            mel_masks,
            max_mel_len,
            p_targets,
            e_targets,
            d_targets,
            p_control,
            e_control,
            d_control,
        )

        output, mel_masks = self.decoder(output, mel_masks)
        output = self.mel_linear(output)

        return (
            output,
            p_predictions,
            e_predictions,
            d_predictions,
            d_rounded,
            src_masks,
            mel_masks,
            src_lens,
            mel_lens,
        )

    def encode_text(self, texts, src_lens, max_src_len):
        """Encode phoneme indices up to (and including) the text encoder.

        Args:
            texts: Integer index tensor fed to the phoneme embedding.
            src_lens: Source lengths, used for the mask and the encoder.
            max_src_len: Maximum source length for the mask.

        Returns:
            ``(x, src_masks)``: the text-encoder output and the source mask
            produced by ``get_mask_from_lengths``.
        """
        src_masks = get_mask_from_lengths(src_lens, max_src_len)

        # Scale embeddings by sqrt(word_dim); see the matching init in __init__.
        embedded = self.embedding(texts) * math.sqrt(self.word_dim)
        x = self.prenet(embedded, src_masks)
        x = self.encoder(x, src_lens)

        return x, src_masks