from utils.tools import get_mask_from_lengths
from utils.dpp_tools import dpp_inference
import torch
import torch.nn as nn 
import math 

class DPP_helper(nn.Module):
    """Hold a set of already-built sub-modules and expose thin wrappers around them.

    The constructor only stores the objects it is given; nothing is created
    here. The methods forward to ``variance_adaptor``, ``decoder`` and
    ``mel_linear``; ``inference`` additionally relies on the module-level
    ``dpp_inference`` helper.
    """

    def __init__(self, embedding, prenet, encoder, variance_adaptor, decoder, mel_linear):
        """Store the supplied components as attributes, unchanged."""
        super().__init__()
        self.embedding = embedding
        self.prenet = prenet
        self.encoder = encoder
        self.variance_adaptor = variance_adaptor
        self.decoder = decoder
        self.mel_linear = mel_linear

    def adapt(self, h_seq, h_mask, duration=None, pitch=None, noise_scale=None):
        """Return ``variance_adaptor.inference(h_seq, h_mask, duration, pitch)``.

        ``noise_scale`` is accepted for call compatibility but is not
        forwarded or otherwise used.
        """
        return self.variance_adaptor.inference(h_seq, h_mask, duration, pitch)

    def adapt2(self, h_seq, h_mask, pitch=None):
        """Return ``variance_adaptor.inference2(h_seq, h_mask, pitch=pitch)``."""
        return self.variance_adaptor.inference2(h_seq, h_mask, pitch=pitch)

    def expand_seq(self, h_seq, duration, src_mask):
        """Call ``variance_adaptor.expand`` and return its three results as a tuple.

        The result is unpacked into exactly three values, so ``expand`` must
        yield three items.
        """
        expanded, lengths, masks = self.variance_adaptor.expand(h_seq, duration, src_mask)
        return expanded, lengths, masks

    def decode(self, x, mel_masks):
        """Run ``decoder`` on ``x`` and project its first output with ``mel_linear``.

        The decoder is expected to return a pair; the second element is
        discarded.
        """
        hidden, _ignored = self.decoder(x, mel_masks)
        return self.mel_linear(hidden)

    def inference(self, kernel, vector, seq,  np_ids, num_cw, half=False):
        """Write the entries chosen by ``dpp_inference`` into row 0 of ``seq``.

        Shape notes from the original author: ``vector`` (after PDM) is
        ``[B, nc+2, t, 1]`` and ``seq`` (before PDM) is ``[B, t]``.

        ``seq`` is modified in place and also returned: row 1 is overwritten
        with a copy of the incoming row 0, then row 0 receives the selected
        values over each span taken from ``np_ids``. When ``half`` is true
        only every second batch entry is used.
        """
        stride = 2 if half else 1
        batch_size = vector.size(0)
        best_idxs = dpp_inference(kernel, num_cw=num_cw)

        # Pick one entry of `vector` per visited batch item. The +2 offset
        # skips the first two slots along dim 1 -- presumably the non-candidate
        # entries implied by the nc+2 size; TODO confirm against the PDM code.
        picked = []
        for row in range(0, len(best_idxs), stride):
            picked.append(vector[row][best_idxs[row] + 2])
        chosen = torch.stack(picked).squeeze(-1)
        spans = np_ids.squeeze(1)

        # Keep a copy of the original row 0 in row 1 before row 0 is edited.
        seq[1] = seq[0].clone()  # shape = [B,T]
        for out_row in range(batch_size // stride):
            span = spans[out_row * stride]
            first = span[0] * 2 + 1
            last = span[1] * 2 + 1
            width = last - first + 1
            # Every span is written into row 0, whichever batch item it came from.
            seq[0, first:last + 1] = chosen[out_row, :width]

        return seq
