import torch
from torch.func import vjp

from core import MonoidReduce, check


def init(a, b):
    """Return the reduction identity for every row of the ``a`` side.

    ``p`` starts at ``-inf`` (the identity of ``logaddexp``) and ``n`` at ``0``
    (the identity of addition). Both are 1-D with one entry per row of the
    first ``a`` tensor, and share its dtype and device.

    ``b`` is unused here; presumably it is part of the callback signature
    expected by ``MonoidReduce``.
    """
    rows, _ = a
    lse_identity = rows.new_full((rows.shape[0],), float('-inf'))
    sum_identity = lse_identity.new_zeros(lse_identity.shape)
    return lse_identity, sum_identity

def islice(parts):
    """Split every tensor in ``parts`` into (up to) two chunks along dim 0 and pair them.

    Returns an iterator of tuples. The k-th tuple holds the k-th chunk of each
    part, in the order the parts were given. ``zip`` stops at the shortest
    chunk list, so parts must agree on their leading size to stay aligned.
    """
    chunked = [torch.chunk(part, 2) for part in parts]
    return zip(*chunked)

def proj_fold(a, b):
    """Fold one (row chunk, class chunk) tile into partial cross-entropy terms.

    Args:
        a: ``(ap, at)``. ``ap`` is a 2-D matrix of row vectors; ``at`` is a
            1-D tensor of per-row target indices (``xentropy`` passes the
            targets ``t`` here).
        b: ``(bc, bi)``. ``bc`` is a 2-D matrix of class vectors; ``bi`` is a
            1-D tensor of their class indices (``xentropy`` passes
            ``arange(num_classes)`` here).

    Returns:
        ``(p, n)``, both 1-D with one entry per row of ``ap``:
        ``p`` is the row-wise logsumexp of this tile's logits ``ap @ bc.T``.
        ``n`` is the sum of the logits whose class index equals the row's
        target, i.e. zero when the target lies outside this class chunk.
    """
    ap, at = a
    bc, bi = b
    logits = ap @ bc.t()

    p = torch.logsumexp(logits, dim=1)
    # Select the target logits with masked_fill rather than multiplying by the
    # boolean mask. `0 * inf` is NaN, so a single non-finite *non-target*
    # logit (e.g. a masked-out class or low-precision overflow) would otherwise
    # poison the whole row. For finite logits the result is unchanged.
    is_target = at[:, None] == bi[None, :]
    n = logits.masked_fill(~is_target, 0.0).sum(1)
    return p, n

def binary_reduce(x, y):
    """Merge two partial ``(p, n)`` results into one.

    The logsumexp parts combine with ``logaddexp``; the target-logit parts
    combine by addition.
    """
    (x_lse, x_target), (y_lse, y_target) = x, y
    return torch.logaddexp(x_lse, y_lse), x_target + y_target

def proj_fold_bwd(a, b, p, gp):
    """Backward pass of ``proj_fold`` for a single tile.

    Args:
        a: ``(ap, at)``, the same row-side inputs given to ``proj_fold``.
        b: ``(bc, bi)``, the same class-side inputs given to ``proj_fold``.
        p: ``(p_total, n_total)``, the fully reduced forward outputs.
        gp: ``(g_total, g_target)``, the upstream gradients for those outputs.

    Returns:
        ``((grad_ap, None), (grad_bc, None))``. The index tensors ``at`` and
        ``bi`` receive no gradient.
    """
    (ap, at), (bc, bi) = a, b
    p_total, _ = p
    g_total, g_target = gp

    # Recompute this tile's logits instead of keeping them from the forward pass.
    logits = torch.matmul(ap, bc.t())
    is_target = at.unsqueeze(1) == bi.unsqueeze(0)
    p_local = torch.logsumexp(logits, dim=1)

    # Assuming tiles are combined by binary_reduce (logaddexp on p, sum on n):
    # d p_total / d p_local = exp(p_local - p_total), and the n gradient
    # passes through unchanged.
    g_local = g_total * torch.exp(p_local - p_total)

    # d p_local / d logits is the tile softmax; d n / d logits is the target mask.
    softmax = torch.exp(logits - p_local.unsqueeze(1))
    g_logits = g_local.unsqueeze(1) * softmax + g_target.unsqueeze(1) * is_target

    # Chain through logits = ap @ bc.T.
    grad_ap = torch.matmul(g_logits, bc)
    grad_bc = torch.matmul(g_logits.t(), ap)
    return (grad_ap, None), (grad_bc, None)

# Assemble the tiled cross-entropy from the callbacks defined above:
#   init           -- per-row identity (p = -inf, n = 0)
#   islice/jslice  -- split the a-side / b-side tensors into two chunks along dim 0
#   proj_fold      -- per-tile (row logsumexp, target-logit sum)
#   binary_reduce  -- merge partials: logaddexp on p, addition on n
#   proj_fold_bwd  -- per-tile gradients given the fully reduced outputs
# NOTE(review): avals/bvals are presumably the number of tensors on each side
# (a = (ap, at), b = (bc, bi), matching how the callbacks unpack them), and
# `XEntropy.apply(...)` is presumably an autograd-Function-style entry point;
# confirm both against core.MonoidReduce.
XEntropy = MonoidReduce(
    'XEntropy', 
    init=init,
    avals=2,
    bvals=2,
    islice=islice,
    jslice=islice, 
    proj_fold=proj_fold, 
    proj_fold_bwd=proj_fold_bwd, 
    binary_reduce=binary_reduce)

def xentropy(p, c, t):
    """Per-row cross-entropy of logits ``p @ c.T`` against targets ``t``, via XEntropy.

    Each row of ``c`` is tagged with its position as a class index, so the
    tiled reduction can match it against ``t``. The result is
    ``logsumexp(logits) - target_logit`` per row. It is intended to match
    ``naive_xentropy`` (the ``__main__`` block checks this).
    """
    class_ids = torch.arange(c.size(0), device=c.device)
    lse, target_logit = XEntropy.apply(p, t, c, class_ids)
    return lse - target_logit

def naive_xentropy(p, c, t):
    """Reference per-row cross-entropy that materializes the full logits matrix.

    Returns ``cross_entropy(p @ c.T, t)`` with no reduction, i.e. one loss per
    row of ``p``.
    """
    functional = torch.nn.functional
    logits = torch.matmul(p, c.t())
    return functional.cross_entropy(logits, t, reduction='none')
    
if __name__ == '__main__':
    # Problem size: rows of p, embedding width, number of classes.
    n_rows = 1024 * 2
    width = 128
    n_classes = 1024 * 2
    precision = torch.double

    embeddings = torch.randn(n_rows, width, dtype=precision, requires_grad=True)
    class_emb = torch.randn(n_classes, width, dtype=precision, requires_grad=True)
    targets = torch.randint(n_classes, (n_rows,))

    # NOTE(review): presumably used by `check` as the upstream gradient for the
    # per-row losses; confirm against core.check.
    grad_out = torch.randn(n_rows, dtype=precision)
    check(xentropy, naive_xentropy, (embeddings, class_emb, targets), grad_out)
