import torch
import torch.nn.functional as F


def get_diag(x, diagonal=0):
    """
    Gather one diagonal from every (T, U) matrix of a batch.

    Diagonals are numbered starting at the top-right corner:
    ``diagonal = -(U-1)`` is the single cell ``x[:, 0, U-1]``,
    ``diagonal = 0`` is the main diagonal ``x[:, i, i]`` and
    ``diagonal = T-1`` is the single cell ``x[:, T-1, 0]``.
    Note that the sign is the opposite of the ``offset`` of ``torch.diagonal``.

    Args:
        x(tensor): (B, T, U).
        diagonal(int): diagonal number, must lie in [-(U-1), T).
    Return:
        tensor: for ``diagonal <= 0`` a (B, L, 1) tensor holding
            ``x[:, i, i - diagonal]``; for ``diagonal > 0`` a (B, 1, L)
            tensor holding ``x[:, diagonal + j, j]``. L is the number of
            cells on the requested diagonal.
    Raises:
        AssertionError: if ``diagonal`` is outside [-(U-1), T).
    """
    B, T, U = x.size()
    n = diagonal + (U-1)
    assert n >= 0 and n < T+U-1, \
           f"diagonal out ot bound! diagonal should belong to [{-(U-1)}, {T})"

    if n < U:
        # On or above the main diagonal: row i pairs with column (i - diagonal).
        # The diagonal ends at the last column or the last row, whichever
        # comes first.
        first_col = U - n - 1
        length = min(T, n + 1)
        index = torch.arange(first_col, first_col + length, device=x.device)
        index = index.view(1, length, 1)
        dim = 2
    else:
        # Below the main diagonal: column j pairs with row (diagonal + j).
        # The length is bounded by the number of columns as well as by the
        # remaining rows; without the U bound the index would be wider than
        # x and gather() would reject it for tall matrices (T > U + 1).
        first_row = n - U + 1
        length = min(U, T - first_row)
        index = torch.arange(first_row, first_row + length, device=x.device)
        index = index.view(1, 1, length)
        dim = 1

    return x.gather(dim=dim, index=index.expand(B, -1, -1))

def get_anti_diag(x, diagonal=0):
    """
    Gather one anti-diagonal of x.

    The last axis of x is reversed and the result is handed to ``get_diag``
    with the same ``diagonal`` number, so the valid range and the output
    layout are those of ``get_diag``.

    Args:
        x(tensor): (B, T, U).
        diagonal(int): diagonal number forwarded to ``get_diag``.
    Return:
        tensor: whatever ``get_diag`` returns for the mirrored input.
    """
    mirrored = torch.flip(x, dims=[-1])
    return get_diag(mirrored, diagonal)

def get_all_anti_diag(x):
    """
    Sum every anti-diagonal of each (T, U) matrix in the batch.

    Args:
        x(tensor): alphas-like, (B, T, U).
    Return:
        tensor: (B, T+U-1). Column k holds the sum of the cells returned by
            ``get_anti_diag(x, k - (U-1))``, i.e. one column per diagonal
            number in [-(U-1), T).
    """
    batch, n_rows, n_cols = x.size()
    per_diag_sums = [
        get_anti_diag(x, offset).view(batch, -1).sum(-1, keepdim=True)  # (B, 1)
        for offset in range(1 - n_cols, n_rows)
    ]
    return torch.cat(per_diag_sums, dim=-1)  # (B, T+U-1)

def make_pad_mask(target_lens, maxlen=None):
    """
    Build a boolean mask that is True at padded positions.

    Args:
        target_lens: per-sequence lengths. Either an object exposing
            ``.tolist()`` (tensor, numpy array, ...) or any iterable of
            integers (list, tuple, ...).
        maxlen(int, optional): width of the mask. Defaults to the largest
            length, or 0 when ``target_lens`` is empty.
    Examples:
        lengths = [5, 3, 2]
        make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    Return:
        bool tensor: (B, maxlen); entry [b, t] is True when
            ``t >= target_lens[b]``. No device is passed when the mask is
            built, so it lives on torch's default device regardless of
            where ``target_lens`` is stored.
    """
    # Normalise the input to a plain Python list so tensors, arrays, lists
    # and tuples all follow the same path.
    if hasattr(target_lens, "tolist"):
        target_lens = target_lens.tolist()
    lengths = torch.tensor(list(target_lens), dtype=torch.int64)  # (B,)

    if maxlen is None:
        # An empty batch has no maximum; fall back to a zero-width mask.
        maxlen = int(lengths.max()) if lengths.numel() > 0 else 0

    positions = torch.arange(0, maxlen, dtype=torch.int64)  # (maxlen,)
    # Broadcast (1, maxlen) against (B, 1) -> (B, maxlen).
    return positions.unsqueeze(0) >= lengths.unsqueeze(-1)


def left_and_down(alpha, phi):
    """
    Propagate alpha across the (T, U) grid one diagonal at a time.

    Starting from the top-right cell, every other cell is recomputed from
    its right neighbour in the same row and from the cell directly above:

        out[i, j] = out[i, j+1] * (1 - phi[i, j+1]) + out[i-1, j] * phi[i-1, j]

    where terms that fall outside the grid count as zero. Only
    ``alpha[:, 0, U-1]`` of the input values survives in the result; all
    other cells are overwritten.

    Args:
        alpha(tensor): initial values, (B, T, U).
        phi(tensor): weight of the "down" move for each cell, (B, T, U).
    Return:
        tensor: updated alpha, (B, T, U).
    """
    B, T, U = alpha.size()
    pad_count = T - 1
    # Take the complement before padding so that padded cells carry 0
    # (not 1) and therefore never leak mass into the grid.
    neg_phi = 1. - phi

    # Zero-pad the last axis on both sides so that every diagonal visited
    # below spans all T rows and can be addressed with one (T, 1) index.
    alpha = F.pad(alpha, (pad_count, pad_count))
    phi = F.pad(phi, (pad_count, pad_count))
    neg_phi = F.pad(neg_phi, (pad_count, pad_count))

    rows = torch.arange(T, device=phi.device).view(1, T, 1)
    # `start` is the padded column at which the source diagonal crosses
    # row 0; diagonals are swept from right to left.
    for start in range(U + T - 2, 0, -1):
        src_index = (rows + start).expand(B, -1, -1)
        dst_index = (rows + (start - 1)).expand(B, -1, -1)

        diag_alpha = alpha.gather(dim=2, index=src_index)
        # "Left" move: stays in the same row, lands one column to the left.
        move_left = diag_alpha * neg_phi.gather(dim=2, index=src_index)
        # "Down" move: lands in the next row of the same column; shifting
        # the contributions by one row aligns them with dst_index.
        move_down = diag_alpha * phi.gather(dim=2, index=src_index)
        move_down = F.pad(move_down, (0, 0, 1, 0))[:, :-1, :]

        # Out-of-place scatter keeps the tensors saved for autograd intact.
        alpha = alpha.scatter(dim=2, index=dst_index, src=move_left + move_down)

    # Strip the padding. An explicit end index is required: with T == 1 the
    # padding is zero-width and a slice ending at -(T-1) == -0 would be empty.
    return alpha[..., pad_count:pad_count + U]


def flip_diag(alpha, phi):
    """
    Run ``left_and_down`` on the column-mirrored grid.

    Both inputs are reversed along the last axis, passed to
    ``left_and_down``, and the result is reversed back so it lines up with
    the original column order.

    Args:
        alpha(tensor): initial values, (B, T, U).
        phi(tensor): same layout as alpha.
    Return:
        tensor: the output of ``left_and_down`` in the original column order.
    """
    mirrored_result = left_and_down(
        torch.flip(alpha, dims=[-1]),
        torch.flip(phi, dims=[-1]),
    )
    return torch.flip(mirrored_result, dims=[-1])


class ComputerAlphas(torch.nn.Module):
    """Parameter-free module that derives alphas from phis via ``flip_diag``."""
    def __init__(self):
        super(ComputerAlphas, self).__init__()

    def forward(self, phis, mask=None):
        """
        Args:
            phis(tensor): prob of shifting next phoneme, (B, T, U).
            mask(bool tensor, optional): non padding mask, (B, T, U). It is
                multiplied into the alphas before the anti-diagonal sums
                are taken. When omitted, the unmasked alphas are summed.
        Returns:
            alphas(tensor): forward algorithm values, (B, T, U). These are
                never masked.
            comsum_st(tensor): anti-diagonal sums of the (masked) alphas as
                returned by ``get_all_anti_diag``, one column per
                anti-diagonal.
        """
        # All-ones start grid with the dtype/device of phis; flip_diag
        # performs the actual recursion.
        alphas = phis.new_ones(phis.size(), requires_grad=True)
        alphas = flip_diag(alphas, phis)

        # The multiplication is out-of-place, so the returned alphas stay
        # untouched. Without a mask the alphas are summed as they are
        # (previously this path hit an unassigned local variable).
        masked_alphas = alphas if mask is None else alphas * mask
        comsum_st = get_all_anti_diag(masked_alphas)

        return alphas, comsum_st

        

if __name__ == "__main__":
    # Small CPU demo: build a (1, 3, 5) grid, run flip_diag and print the
    # inputs, the resulting alphas and a few anti-diagonal views of them.
    torch.manual_seed(1)
    B, T, U = 1, 3, 5
    grid_shape = (B, T, U)
    alpha = torch.ones(grid_shape, requires_grad=True, device='cpu')
    phi = torch.rand(grid_shape, requires_grad=True, device='cpu')

    print(f"alpha: {alpha}")
    print(f"phi: {phi}")
    print(f"1-phi: {1.-phi}")

    alpha_ = flip_diag(alpha, phi)
    print(f"alpha_: {alpha_}")

    # The two right-most diagonal numbers of the mirrored grid.
    for diagonal in (-4, -3):
        print(f"alpha_ anti diag: {get_anti_diag(alpha_, diagonal)}")

    print(f"get_all_anti_diag: {get_all_anti_diag(alpha_)}")
