import torch


def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad):
    """Naive (loop-based) reference implementation of the RNN-T loss and its gradient.

    Args:
        x: unnormalized logits of shape (B, T, U, V). Must be a leaf tensor with
            ``requires_grad=True``: the gradient is produced by back-propagating
            through ``log_softmax`` and returned as ``x.grad``.
        label: target labels, read as ``label[b, u]`` for ``u < y_len[b]``.
        f_len: number of valid time steps per batch element.
        y_len: number of target labels per batch element; the lattice for element
            ``b`` spans ``y_len[b] + 1`` positions along U.
        blank_idx: vocabulary index of the blank symbol.
        loss_grad: upstream gradient of the loss, indexed per batch element.

    Returns:
        ``(alpha, beta, x.grad, loss)``:
        - ``alpha`` and ``beta`` are the forward and backward log-probability lattices.
          They have shape (B, T, U) and are accumulated in float32 for half/float
          inputs (otherwise in ``x.dtype``).
        - ``x.grad`` is the gradient with respect to ``x``. Note that ``backward()``
          adds into any pre-existing ``x.grad``.
        - ``loss`` is ``-beta[:, 0, 0]``, the per-sample negative log-likelihood,
          cast to ``x.dtype``.
    """
    def log_sum_exp(a, b):
        # Stable log(exp(a) + exp(b)). Unlike a hand-written max + log(1 + exp(d)),
        # torch.logaddexp returns -inf (not NaN from -inf - -inf) when both
        # operands are -inf, i.e. when neither incoming transition is reachable.
        return torch.logaddexp(a, b)

    def forward_alpha(x, label, f_len, y_len, blank_idx):
        # alpha[b, t, u]: log-probability of reaching lattice node (t, u) from (0, 0).
        B, T, U, V = x.size()
        acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
        alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
        for b in range(B):
            alpha[b, 0, 0] = 0
            # First column: only blank transitions along t.
            for t in range(1, f_len[b]):
                alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx]
            # First row: only label transitions along u.
            for u in range(1, y_len[b]+1):
                alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]]
            for t in range(1, f_len[b]):
                for u in range(1, y_len[b]+1):
                    curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx]
                    next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]]
                    alpha[b, t, u] = log_sum_exp(curr_, next_)
        return alpha

    def forward_beta(x, label, f_len, y_len, blank_idx):
        # beta[b, t, u]: log-probability of completing the alignment from node (t, u),
        # including the terminal blank emitted at (f_len[b]-1, y_len[b]).
        B, T, U, V = x.shape
        acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype
        beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device)
        for b in range(B):
            beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx]
            for t in range(f_len[b]-2, -1, -1):
                beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx]
            for u in range(y_len[b]-1, -1, -1):
                beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]]
            for t in range(f_len[b]-2, -1, -1):
                for u in range(y_len[b]-1, -1, -1):
                    curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx]
                    next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]]
                    beta[b, t, u] = log_sum_exp(curr_, next_)
        return beta

    def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx):
        # Gradient of sum_b loss_grad[b] * loss[b] with respect to the log-probs x.
        grad = torch.zeros_like(x)
        B = x.size(0)
        for b in range(B):
            # log(loss_grad[b]) + alpha[b, t, u] - log P(y_b | x_b), restricted to
            # this sample so each iteration does O(T*U) rather than O(B*T*U) work.
            common_factor = torch.log(loss_grad[b]) + alpha[b] - beta[b, 0, 0]

            # "next" (label) transitions: (t, u) -> (t, u+1) emitting label[b, u].
            for u in range(y_len[b]):
                grad[b, :f_len[b], u, label[b, u]] = -torch.exp(
                    common_factor[:f_len[b], u]
                    + beta[b, :f_len[b], u+1]
                    + x[b, :f_len[b], u, label[b, u]])

            # "current" (blank) transitions: (t, u) -> (t+1, u).
            grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] = -torch.exp(
                common_factor[:f_len[b]-1, :y_len[b]+1]
                + beta[b, 1:f_len[b], :y_len[b]+1]
                + x[b, :f_len[b]-1, :y_len[b]+1, blank_idx])

            # Terminal blank emitted at the last lattice node.
            grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(
                common_factor[f_len[b]-1, y_len[b]]
                + x[b, f_len[b]-1, y_len[b], blank_idx])

        return grad

    x_log = torch.nn.functional.log_softmax(x, dim=-1)
    alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx)
    beta = forward_beta(x_log, label, f_len, y_len, blank_idx)
    grad = backward(x_log, label, f_len, y_len, alpha, beta,
                    loss_grad, blank_idx)
    # Chain the lattice gradient through log_softmax into x.grad.
    x_log.backward(grad)
    loss = -beta[:, 0, 0]
    loss = loss.to(x.dtype)
    return alpha, beta, x.grad, loss


def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout,
                                dropout_prob=0, mask=None):
    """Reference implementation of the transducer joint (broadcast add) and its gradient.

    Computes ``h[b, t, u] = f[b, t] + g[b, u]``, optionally followed by ReLU and
    mask-based dropout, back-propagates ``h_grad`` through it, and returns the
    result along with the gradients accumulated in ``f.grad`` and ``g.grad``.

    Args:
        f: tensor of shape (B, T, H); must require grad.
        g: tensor broadcast along T (presumably (B, U, H)); must require grad.
        h_grad: upstream gradient with the shape of the dense output (B, T, U, H).
        f_len, g_len: number of valid entries along T and U per batch element.
        pack_output: if false, return the dense output with the padding region
            set to -1; otherwise concatenate the valid ``f_len[b] * g_len[b]``
            rows of every batch element into a (N, H) tensor.
        relu: apply ReLU after the addition.
        dropout: multiply by ``mask`` and rescale by ``1 / (1 - dropout_prob)``.
        dropout_prob: dropout probability used for the rescale factor.
        mask: dropout mask multiplied into the output; required when ``dropout``.

    Returns:
        ``(h, f.grad, g.grad)``, where ``h`` is dense or packed per ``pack_output``.

    Raises:
        NotImplementedError: if ``dropout`` is requested without a ``mask``.
    """
    if dropout and mask is None:
        raise NotImplementedError("mask needs to supplied to test dropout.")
    B, _, H = f.size()
    h = f.unsqueeze(dim=2) + g.unsqueeze(dim=1)
    if relu:
        h = torch.nn.functional.relu(h)
    if dropout:
        # Must be out-of-place: ReLU saves its output for backward, so an in-place
        # multiply on it makes h.backward() fail with a version-counter error.
        scale = 1/(1-dropout_prob)
        h = h * mask
        h = h * scale
    h.backward(h_grad)

    if not pack_output:
        # intentionally set don't-care region to -1 to test if transducer joint
        # writes these regions to avoid NaN and inf
        for b in range(B):
            h[b, f_len[b]:] = -1
            h[b, :, g_len[b]:] = -1

        return h, f.grad, g.grad

    # packing: keep only the valid (t, u) rows of each batch element, in order.
    h_packed = torch.cat([h[b, :f_len[b], :g_len[b], :].reshape(-1, H) for b in range(B)])
    return h_packed, f.grad, g.grad
