import torch

from core import MonoidReduce, check

def init(a, b):
    """Build the starting accumulator for the reduction.

    ``a`` is a 1-tuple ``(aq,)`` and ``b`` a 2-tuple whose second entry is
    ``bv``.  Returns ``(z, v)`` where ``z`` has one ``-inf`` entry per row of
    ``aq`` (same dtype/device as ``aq``) and ``v`` is an all-zero matrix with
    ``aq.shape[0]`` rows and ``bv.shape[1]`` columns (same dtype/device as
    ``bv``).
    """
    (queries,) = a
    _, values = b
    n_rows = queries.size(0)
    n_cols = values.size(1)

    # -inf is the neutral element of logaddexp; zeros the matching value.
    log_norm = queries.new_full((n_rows,), float('-inf'))
    weighted = values.new_zeros((n_rows, n_cols))
    return log_norm, weighted

def islice(parts, chunks=8):
    """Split every tensor in ``parts`` along dim 0 and yield aligned slices.

    Each tensor is passed through ``Tensor.chunk(chunks)``; the result is an
    iterator of tuples, where the i-th tuple holds the i-th chunk of every
    tensor in ``parts``.

    Args:
        parts: iterable of tensors to be sliced together.
        chunks: number of chunks requested from ``Tensor.chunk``.  Defaults
            to 8, the value that used to be hard-coded.

    Returns:
        A ``zip`` iterator over tuples of chunks.  Iteration stops at the
        tensor that produced the fewest chunks.
    """
    return zip(*(part.chunk(chunks) for part in parts))

def proj_fold(a, b):
    """Attend one block of queries to one block of keys/values.

    ``a`` is ``(aq,)`` and ``b`` is ``(bk, bv)``.  With
    ``logits = aq @ bk.t()``, returns ``(z, v)`` where ``z`` is the row-wise
    log-sum-exp of ``logits`` and ``v`` is the row-wise softmax of ``logits``
    multiplied by ``bv``.
    """
    (queries,) = a
    keys, values = b

    scores = queries @ keys.t()
    # Subtract the row maximum before exponentiating for numerical stability.
    row_max = scores.max(1).values
    shifted = scores - row_max[:, None]
    log_sum = shifted.exp().sum(1).log()

    # exp(shifted - log_sum) is the softmax of the original scores.
    weights = (shifted - log_sum[:, None]).exp()
    return row_max + log_sum, weights @ values

def binary_reduce(x, y):
    """Merge two partial results ``(z, v)`` into one.

    ``z`` values are combined with ``logaddexp``; the ``v`` values are
    averaged with weights ``exp(xz - z)`` and ``exp(yz - z)``.

    Rows for which both inputs are empty (``xz == yz == -inf``, e.g. two
    accumulators straight out of ``init``) keep ``z == -inf`` and get zero
    weights, instead of the NaN that ``-inf - (-inf)`` would otherwise
    produce.

    Args:
        x: pair ``(xz, xv)``; ``xz`` indexed by row, ``xv`` by row and column.
        y: pair ``(yz, yv)`` laid out like ``x``.

    Returns:
        The merged pair ``(z, v)``.
    """
    xz, xv = x
    yz, yv = y

    z = torch.logaddexp(xz, yz)
    # Only rows with z == -inf are touched: using 0 as the reference there
    # makes both weights exp(-inf) == 0 rather than exp(nan).
    ref = z.masked_fill(z == float('-inf'), 0.0)
    v = (xz - ref).exp()[:, None] * xv + (yz - ref).exp()[:, None] * yv
    return z, v

def proj_fold_bwd(a, b, p, gp):
    """Gradients of the attention output w.r.t. one query/key/value block.

    ``a`` is ``(aq,)``, ``b`` is ``(bk, bv)``, ``p`` is ``(pz, pv)`` and
    ``gp`` is ``(gz, gv)``.  The first entry of ``gp`` is ignored: z is only
    used for aggregation.

    Returns ``((gaq,), (gbk, gbv))``.
    """
    (queries,) = a
    keys, values = b
    log_norm, out = p
    _, grad_out = gp

    # Recompute the attention weights of this block from the logits and the
    # supplied log-normaliser pz.
    scores = queries @ keys.t()
    weights = (scores - log_norm[:, None]).exp()

    # NOTE: (gv * pv).sum(1) depends only on p and gp, so it could be
    # computed once up front instead of on every call.
    row_dot = (grad_out * out).sum(1)
    grad_scores = (grad_out @ values.t() - row_dot[:, None]) * weights

    grad_values = weights.t() @ grad_out
    grad_queries = grad_scores @ keys
    grad_keys = grad_scores.t() @ queries

    return (grad_queries,), (grad_keys, grad_values)

# Attention expressed as a MonoidReduce built from the functions above.
# NOTE(review): avals/bvals presumably give the number of tensors on the
# a-side (aq) and b-side (bk, bv) — confirm against core.MonoidReduce.
Attention = MonoidReduce(
    'Attention',
    init=init,
    avals=1,
    bvals=2,
    islice=islice,
    jslice=islice,
    proj_fold=proj_fold,
    proj_fold_bwd=proj_fold_bwd,
    binary_reduce=binary_reduce,
)

def attention(q, k, v):
    """Run ``Attention.apply`` on ``(q, k, v)`` and return its second output.

    The first output (z) is discarded; the result is cast to ``k.dtype``.
    """
    _, out = Attention.apply(q, k, v)
    return out.to(k.dtype)

def naive_attention(q, k, v):
    """Dense reference: ``softmax(q @ k.t(), dim=1) @ v``."""
    scores = torch.matmul(q, k.t())
    weights = torch.softmax(scores, dim=1)
    return torch.matmul(weights, v)

if __name__ == '__main__':
    # Self-check on random double-precision inputs: hand the chunked
    # implementation and the dense reference to `check`.
    n_queries, qk_width = 8 * 1024, 256
    n_keys, v_width = n_queries, qk_width
    dtype = torch.double

    q = torch.randn(n_queries, qk_width, dtype=dtype, requires_grad=True)
    k = torch.randn(n_keys, qk_width, dtype=dtype, requires_grad=True)
    v = torch.randn(n_keys, v_width, dtype=dtype, requires_grad=True)

    # NOTE(review): `mock` has the shape of the attention output; presumably
    # `check` uses it as the upstream gradient — confirm in core.check.
    mock = torch.randn(n_queries, v_width, dtype=dtype)

    check(attention, naive_attention, (q, k, v), mock)
