import torch
import cola

# Scratch check: the matvec of a cola Kronecker operator built from two random
# (N, N) factors, compared against two references computed in plain torch:
#   1. the dense Kronecker product torch.kron(K1, K2) @ v
#   2. the row-major "vec trick": (K1 ⊗ K2) @ vec(X) == vec(K1 @ X @ K2.T),
#      where vec() is a row-major flatten and X = v.reshape(N, N).
torch.manual_seed(21)

N = 3
N2 = N**2

# K below is an (N2, N2) operator, so v must have exactly N2 entries. It must
# also be floating point: K1/K2 come from randn (float32), and torch matmul
# rejects mixing int64 with float32.
v = torch.arange(N2, dtype=torch.float32) + 1
# Row-major reshape, so X[k, l] == v[k * N + l].
X = v.reshape((N, N))

K1 = torch.randn((N, N))
K2 = torch.randn((N, N))
K = cola.ops.Kronecker(K1, K2)

Kv = K @ v
print(Kv)

# Reference results computed without cola.
# NOTE(review): assumes cola.ops.Kronecker orders its factors like torch.kron
# (row-major, K1 as the outer factor); the printed booleans confirm or refute it.
dense = torch.kron(K1, K2) @ v
vec_trick = (K1 @ X @ K2.T).reshape(-1)
# atol loosened from the default 1e-8: float32 summation order may differ.
print(torch.allclose(Kv, dense, atol=1e-5))
print(torch.allclose(Kv, vec_trick, atol=1e-5))
