import torch
from torch.func import vjp

from core import MonoidReduce, check


def init(a, b):
    """Create the starting accumulator ``(p, q, n)`` for the reduction.

    ``a`` is unpacked as a pair; only its first tensor is used, to pick the
    leading dimension and (via ``new_full`` / ``new_zeros``) the dtype and
    device of the accumulator. ``b`` is accepted but not read.

    Returns three separate 1-D tensors of length ``a[0].shape[0]``: the two
    log-sum-exp slots start at ``-inf`` and the weighted-sum slot at ``0``.
    """
    lead, _ = a
    shape = (lead.shape[0],)
    neg_inf = float('-inf')
    # Two distinct allocations so the log-sum-exp slots never share storage.
    log_p = lead.new_full(shape, neg_inf)
    log_q = lead.new_full(shape, neg_inf)
    weighted = lead.new_zeros(shape)
    return log_p, log_q, weighted

def islice(parts, num_chunks=8):
    """Iterate over aligned chunks of every tensor in ``parts``.

    Each element of ``parts`` is split with ``.chunk(num_chunks)`` (along its
    first dimension, the default for ``chunk``) and the pieces are zipped, so
    every yielded item is a tuple holding the k-th chunk of each part.

    Args:
        parts: iterable of objects exposing ``.chunk`` (tensors in this file).
        num_chunks: requested number of chunks per part. Defaults to 8, the
            value that was previously hard-coded, so existing one-argument
            callers are unaffected. ``chunk`` may return fewer pieces than
            requested when the dimension is small.

    Returns:
        A one-shot ``zip`` iterator of per-chunk tuples. Because it is a
        ``zip``, iteration stops at the shortest chunk sequence.
    """
    return zip(*(part.chunk(num_chunks) for part in parts))

def proj_fold(a, b):
    """Project one (a, b) chunk pair and reduce it along the b axis.

    ``a`` is a pair of left-hand matrices and ``b`` a pair of right-hand
    matrices; the first of each are multiplied together, likewise the second.

    Returns ``(p, q, n)``, each one value per row of the left operands:
      * ``p`` -- log-sum-exp of the first score matrix,
      * ``q`` -- log-sum-exp of the second score matrix,
      * ``n`` -- first scores averaged with weights ``exp(second - q)``,
        i.e. the row-wise softmax of the second scores.
    """
    left_p, left_q = a
    right_p, right_q = b

    # Pairwise score matrices, both L x R.
    scores_p = torch.matmul(left_p, right_p.t())
    scores_q = torch.matmul(left_q, right_q.t())

    # Row-wise reductions, each of length L.
    lse_p = scores_p.logsumexp(dim=1)
    lse_q = scores_q.logsumexp(dim=1)
    weights = torch.exp(scores_q - lse_q.unsqueeze(1))
    mixed = (weights * scores_p).sum(dim=1)
    return lse_p, lse_q, mixed

def binary_reduce(x, y):
    """Combine two partial ``(p, q, n)`` accumulators into one.

    ``p`` and ``q`` are merged with ``logaddexp``. ``n`` is a weighted
    average whose weights are ``exp(q_part - q_total)``, so each side's
    ``n`` is rescaled by its share of the combined ``q`` mass.

    Args:
        x, y: triples ``(p, q, n)`` of broadcast-compatible tensors.

    Returns:
        The combined ``(p, q, n)`` triple.
    """
    xp, xq, xn = x
    yp, yq, yn = y
    p = torch.logaddexp(xp, yp)
    q = torch.logaddexp(xq, yq)

    # Where both sides are still empty (q == -inf on both, as produced by
    # ``init``), ``xq - q`` is ``-inf - (-inf) = nan``. Neither side carries
    # any mass there, so force both weights to zero instead of letting NaN
    # leak into ``n``; this keeps ``init`` a true identity element.
    empty = torch.isneginf(q)
    no_weight = torch.zeros_like(q)
    wx = torch.where(empty, no_weight, (xq - q).exp())
    wy = torch.where(empty, no_weight, (yq - q).exp())

    n = xn * wx + yn * wy
    return p, q, n

def proj_fold_bwd(a, b, p, gp):
    """Backward pass for one (a, b) chunk pair.

    Re-runs ``proj_fold`` on the chunk under ``torch.func.vjp``, converts the
    gradients given for the ``(p, q, n)`` triple in ``p`` into gradients for
    this chunk's local ``(p, q, n)``, and pulls them back to ``a`` and ``b``.

    Args:
        a, b: the chunk inputs, in the form ``proj_fold`` accepts.
        p: triple ``(p, q, n)`` -- presumably the fully reduced forward
            result; confirm against the caller in ``core``.
        gp: triple of gradients matching the entries of ``p``.

    Returns:
        Whatever the ``vjp`` closure returns: the gradients for ``a`` and ``b``.
    """
    # Recompute the chunk-local forward values and capture the VJP closure.
    local_out, pullback = vjp(proj_fold, a, b)
    local_p, local_q, local_n = local_out

    total_p, total_q, total_n = p
    grad_p, grad_q, grad_n = gp

    # This chunk's share of each log-sum-exp; these are the derivatives of
    # the logaddexp merges in ``binary_reduce`` w.r.t. the local values.
    share_p = torch.exp(local_p - total_p)
    share_q = torch.exp(local_q - total_q)

    cotangents = (
        grad_p * share_p,
        # The local q influences the total q directly and the total n through
        # the weights, hence the extra (local_n - total_n) term.
        (grad_q + grad_n * (local_n - total_n)) * share_q,
        grad_n * share_q,
    )

    # Chain through the chunk-local projection.
    return pullback(cotangents)

# Reduction spec wiring together the helpers defined in this module.
# NOTE(review): avals/bvals presumably give the number of tensors in the
# `a` and `b` groups (both helpers unpack pairs) -- confirm in core.MonoidReduce.
XEntropy2 = MonoidReduce(
    'XEntropy2',
    init=init,
    avals=2,
    bvals=2,
    islice=islice,
    jslice=islice,
    proj_fold=proj_fold,
    proj_fold_bwd=proj_fold_bwd,
    binary_reduce=binary_reduce,
)

def xentropy2(p, c, q, d):
    """Soft-target cross-entropy computed through the ``XEntropy2`` reduction.

    The arguments are regrouped as ``(p, q, c, d)`` before being handed to
    ``XEntropy2.apply``; ``proj_fold`` then pairs ``p`` with ``c`` and ``q``
    with ``d``. The result is the first reduced value minus the third, which
    is the same quantity ``naive_xentropy2`` computes directly.
    """
    log_norm, _, expected_score = XEntropy2.apply(p, q, c, d)
    return log_norm - expected_score

def naive_xentropy2(p, c, q, d):
    """Reference implementation that materialises the full score matrices.

    Logits are ``p @ c.T``; the soft targets are the row-wise softmax of
    ``q @ d.T``. Returns the unreduced cross-entropy, one value per row.
    """
    logits = torch.matmul(p, c.t())
    targets = torch.softmax(torch.matmul(q, d.t()), dim=1)
    return torch.nn.functional.cross_entropy(logits, targets, reduction='none')
    
if __name__ == '__main__':
    # L rows on the left operands, R rows on the right, D feature columns
    # (the q/d pair uses 2*D columns).
    L = R = 2 * 1024
    D = 128
    dtype = torch.double

    # Created in the order p, c, q, d so random draws happen in that order.
    shapes = ((L, D), (R, D), (L, 2 * D), (R, 2 * D))
    p, c, q, d = (
        torch.randn(*shape, requires_grad=True, dtype=dtype)
        for shape in shapes
    )

    # NOTE(review): mock uses the default dtype rather than `dtype`;
    # presumably it is the output-gradient seed for `check` -- confirm in core.
    mock = torch.randn(L)
    check(xentropy2, naive_xentropy2, (p, c, q, d), mock)
