import torch
import math

# Experiment: magnitude of the gradients that flow back through single-head
# causal self-attention  O = softmax(Q K^T / sqrt(d), causal) V  for random
# Gaussian inputs and a random upstream gradient G.


def causal_mask(n_q, n_k, device=None):
    """Return a boolean (n_q, n_k) mask that is True where key j > query i.

    True entries mark the "future" keys a query is not allowed to attend to.
    """
    return torch.triu(
        torch.ones(n_q, n_k, dtype=torch.bool, device=device), diagonal=1
    )


def causal_attention(q, k, v):
    """Compute causal scaled dot-product attention.

    Args:
        q: queries, shape (..., L_q, d).
        k: keys, shape (..., L_k, d).
        v: values, shape (..., L_k, d_v).

    Returns:
        (scores, probs, out): the masked pre-softmax scores (..., L_q, L_k),
        the attention weights (..., L_q, L_k), and the output (..., L_q, d_v).
    """
    mask = causal_mask(q.shape[-2], k.shape[-2], device=q.device)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    # Fill with -inf so future keys get exactly zero weight (and zero
    # gradient); a finite additive penalty such as -100 still leaks weight
    # whenever a row's scores span that much.
    scores = scores.masked_fill(mask, float("-inf"))
    # Query i always sees key 0, so no row is entirely -inf and softmax
    # cannot produce NaN.
    probs = torch.softmax(scores, dim=-1)
    return scores, probs, probs @ v


n = 1000
d = 800
Q = torch.randn((n, d), requires_grad=True)
K = torch.randn((n, d), requires_grad=True)
V = torch.randn((n, d), requires_grad=True)

print(causal_mask(n, n))

S, P, O = causal_attention(Q, K, V)
print(S)

print()

# Per-row mean squared activation of the attention output.
print(O.size(), O.pow(2).mean(dim=1))
G = torch.randn((n, d))
L = (O * G).sum()
L.backward()
print("V grad: ", (V.grad**2).mean())
print("K grad: ", (K.grad**2).mean())
print("Q grad: ", (Q.grad**2).mean())
