import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.hparams import hparams
from losses.compute_alphas_torch import ComputerAlphas


class RNNTLossTTS(nn.Module):
    """RNN-T style alignment loss for TTS.

    Combines an alignment-weighted L1 mel reconstruction loss with
    regularizers on the alignment produced by ``ComputerAlphas``.
    """

    def __init__(self):
        super().__init__()
        self.get_alphas = ComputerAlphas()
        # Half-width, along the last (U) axis, of the band built around the
        # reference alignment in get_diag_soft.
        self.tau = hparams['tau']

    def forward(self, acts, labels, act_lens, label_lens, alphas_gt, phis_gt):
        """Compute the alignment-weighted mel loss and its regularizers.

        Args:
            acts (tensor): jointnet outputs, (B, T, U, n_mels+1), where B is the
                minibatch size, T is the maximum number of input tokens and U is
                the maximum number of output frames. The last channel is passed
                through a sigmoid to give the shift probabilities ``phis``; the
                remaining channels are the predicted mels.
            labels (tensor): target mels, (B, U, n_mels).
            act_lens (tensor): length of each input in the minibatch, (B,).
            label_lens (tensor): length of each target in the minibatch, (B,).
            alphas_gt (tensor): reference attention map, (B, T, U).
            phis_gt (tensor): reference shift labels, (B, T, U).
        Returns:
            tuple of four scalar tensors:
                weighted_mel_loss: L1 mel loss weighted by ``alphas * (1 - phis)``
                    inside the mask, normalized by (valid frames * n_mels).
                diag_cumsum_constrain: masked MSE pulling every valid entry of
                    the ``diag_cumsum`` returned by ``get_alphas`` toward 1.
                diag_constrain: mean alignment mass outside the mask.
                ref_mel_loss: the same mel loss weighted by the reference
                    alignment, computed on the detached mel loss.
        """
        device = acts.device
        _, T, U, H = acts.size()
        n_mels = H - 1

        # Split predicted mels (B, T, U, n_mels) and shift probabilities (B, T, U).
        mels, phis = acts[..., :-1], torch.sigmoid(acts[..., -1])

        # Non-padding masks sized to the real tensor extents T and U. The
        # length maxima can be smaller than T/U when acts is padded beyond the
        # longest sequence in this batch; sizing from the lengths would then
        # break the elementwise mask combination below.
        non_pad_mask, _, maskU = self.get_mask(act_lens, label_lens, max_T=T, max_U=U)
        non_pad_mask, maskU = non_pad_mask.to(device), maskU.to(device)

        # Band of half-width tau around the reference alignment, (B, T, U).
        diag_soft_mask = self.get_diag_soft(alphas_gt, self.tau).to(device)
        mask = non_pad_mask & diag_soft_mask

        # Alignment alphas (B, T, U) and per-diagonal sums.
        # NOTE(review): diag_cumsum is assumed to be (B, T+U-1) so that it lines
        # up with diag_cumsum_mask below — confirm against ComputerAlphas.
        alphas, diag_cumsum = self.get_alphas(phis, mask)

        # Elementwise L1 mel loss, (B, T, U, n_mels). Labels broadcast over T
        # instead of being copied T times.
        mel_loss = (mels - labels.unsqueeze(1)).abs()
        mask_4d = mask.unsqueeze(-1)
        # Shared normalizer: number of valid output frames times mel channels.
        denom = maskU.sum() * n_mels

        weights = (alphas * (1. - phis)).unsqueeze(-1)
        weighted_mel_loss = (weights * mel_loss * mask_4d).sum() / denom

        # Masked MSE of diag_cumsum against 1 over the act_len+label_len-1
        # valid diagonals of each example.
        diag_cumsum_mask = (~self.make_pad_mask(act_lens + label_lens - 1, maxlen=T + U - 1)).float().to(device)
        diag_cumsum_constrain = (diag_cumsum - 1.) ** 2
        diag_cumsum_constrain = (diag_cumsum_constrain * diag_cumsum_mask).sum() / diag_cumsum_mask.sum()

        # Penalize alignment mass outside the mask. The count is clamped so an
        # all-True mask gives 0 instead of 0/0 = NaN.
        off_mask = ~mask
        diag_constrain = (alphas * off_mask).sum() / off_mask.sum().clamp(min=1)

        # Reference loss under the ground-truth alignment. mel_loss is
        # detached, so this term does not backpropagate into acts.
        ref_weights = (alphas_gt * (1. - phis_gt)).unsqueeze(-1)
        ref_mel_loss = (ref_weights * mel_loss.detach() * mask_4d).sum() / denom

        return weighted_mel_loss, diag_cumsum_constrain, diag_constrain, ref_mel_loss

    def make_pad_mask(self, target_lens, maxlen=None):
        """Return a boolean padding mask that is True at padded positions.

        Examples:
            lengths = [5, 3, 2]
            make_pad_mask(lengths)
            masks = [[0, 0, 0, 0, 0],
                     [0, 0, 0, 1, 1],
                     [0, 0, 1, 1, 1]]

        Args:
            target_lens: per-example lengths, (B,). Accepts a list, a tuple, or
                any object with ``.tolist()`` (e.g. a tensor).
            maxlen (int, optional): width of the mask. Defaults to
                ``max(target_lens)``, or 0 for an empty batch.
        Returns:
            bool tensor of shape (B, maxlen), built on CPU.
        """
        if not isinstance(target_lens, (list, tuple)):
            target_lens = target_lens.tolist()
        if maxlen is None:
            maxlen = int(max(target_lens, default=0))

        lengths = torch.tensor(target_lens, dtype=torch.int64).unsqueeze(-1)  # (B, 1)
        positions = torch.arange(0, maxlen, dtype=torch.int64).unsqueeze(0)  # (1, maxlen)
        return positions >= lengths  # broadcasts to (B, maxlen)

    def get_mask(self, act_lens, label_lens, max_T=None, max_U=None):
        """Build non-padding masks over inputs, targets and their grid.

        Args:
            act_lens: input lengths, (B,).
            label_lens: target lengths, (B,).
            max_T (int, optional): width of the T mask; defaults to max(act_lens).
            max_U (int, optional): width of the U mask; defaults to max(label_lens).
        Returns:
            (non_pad_mask (B, T, U), non_pad_mask_T (B, T), non_pad_mask_U (B, U)),
            all bool and True at valid (non-padded) positions.
        """
        non_pad_mask_T = ~self.make_pad_mask(act_lens, maxlen=max_T)   # (B, T)
        non_pad_mask_U = ~self.make_pad_mask(label_lens, maxlen=max_U)  # (B, U)
        # Outer AND: a grid cell is valid only if both its t and u are valid.
        non_pad_mask = non_pad_mask_T.unsqueeze(-1) & non_pad_mask_U.unsqueeze(-2)

        return non_pad_mask, non_pad_mask_T, non_pad_mask_U

    def get_diag_soft(self, alphas_gt, tau=1):
        """Widen a reference alignment into a band along its last axis.

        A position is True when the sum of ``alphas_gt`` over offsets
        ``-tau..tau`` along the last dimension is nonzero. For a non-negative
        map this means "nonzero within tau steps". Negative entries could
        cancel in the sum.

        Returns:
            bool tensor with the same shape as ``alphas_gt``.
        """
        def shift_alphas(alphas, tau):
            alphas_ = alphas.clone()
            for i in range(1, tau + 1):  # shift left by i along the last axis
                alphas_ += F.pad(alphas, (0, i))[..., i:]
            for i in range(1, tau + 1):  # shift right by i along the last axis
                alphas_ += F.pad(alphas, (i, 0))[..., :-i]
            return alphas_

        return shift_alphas(alphas_gt, tau).bool()


