from df_attn import cum_by_inv, cum_by_inv_backward
from test_attn import clone_with_grad, check_close
import torch


# Gradient check for `cum_by_inv`: reference gradients come from autograd
# (y1.backward), candidate gradients from the hand-written
# `cum_by_inv_backward`; both are computed from the same random inputs and
# compared with `check_close`.

torch.set_default_device('cuda')
torch.set_default_dtype(torch.float)
# Fixed seed so any mismatch reported below can be reproduced exactly.
torch.manual_seed(0)

bs = 1
t = 256
h = 128
w = torch.randn(bs, t, t)   # shape (bs, t, t)
a = torch.randn(bs, t, h)   # shape (bs, t, h)
dy = torch.randn(bs, t, h)  # upstream gradient, shared by both backward paths

# Separate copies per path; presumably fresh grad-requiring leaves
# (see test_attn.clone_with_grad) so one path cannot disturb the other.
w1 = clone_with_grad(w)
w2 = clone_with_grad(w)
a1 = clone_with_grad(a)
a2 = clone_with_grad(a)

# Reference: let autograd differentiate through the forward op.
y1 = cum_by_inv(a1, w1)
y1.backward(dy)

# Candidate: the manual backward consumes the forward output y2 (not a2).
# No autograd graph is needed on this path, so don't record one; this also
# keeps da/dw free of grad history when they are compared.
with torch.no_grad():
    y2 = cum_by_inv(a2, w2)
    da, dw = cum_by_inv_backward(dy, w2, y2)

check_close(a1.grad, da, name='da')
check_close(w1.grad, dw, name='dw', print_details=True)
