import torch
from cola.ops import Tridiagonal
import tensorly as tl
from tensorly.decomposition import matrix_product_state

# Make tensorly operate on torch tensors and seed the RNG so the random
# tridiagonal entries (and therefore every printed value) are reproducible.
tl.set_backend('pytorch')
torch.manual_seed(21)

# --- Random tridiagonal operator ---------------------------------------------
N = 12
# Draw order matters for reproducibility: the (N - 1)-length vector first,
# then the N-length vector.
alpha = torch.randn(size=(N - 1, 1))
beta = torch.randn(size=(N, 1))
# gamma receives an independent copy of alpha, so both off-diagonal arguments
# carry the same values.
T = Tridiagonal(alpha=alpha, beta=beta, gamma=alpha.clone())
xnp = T.xnp  # backend namespace taken from the operator; used below for `norm`

# --- Tensorize the dense matrix ----------------------------------------------
# Assumes to_dense() yields an N x N matrix, so N**2 == 144 == 4*3*3*2*2
# entries fit the 5-way reshape below — confirm if N changes.
dense = T.to_dense().clone()
tensor = tl.tensor(dense).reshape((4, 3, 3, 2, 2))

# --- Tensor-train (MPS) decomposition ----------------------------------------
# NOTE(review): several requested bond ranks exceed what the unfoldings allow
# (e.g. the first bond can be at most 4). Presumably tensorly truncates them,
# which makes this an effectively full-rank, near-exact decomposition — confirm.
tt_ranks = [1, 12, 12, 8, 4, 1]
factors = matrix_product_state(tensor, rank=tt_ranks)
approx = tl.tt_to_tensor(factors)

# Reconstruction error: norm of the residual between TT reconstruction and input.
residual = approx - tensor
print(xnp.norm(residual))

# --- Spot-check one entry against an explicit core contraction ---------------
i, j, k, r, p = 0, 2, 1, 1, 1
index = (i, j, k, r, p)
print(tensor[index])
# Multiply the selected slice of each core from left to right (same association
# order as a chained `a @ b @ c @ d @ e`). The outer ranks were requested as 1,
# so the product is expected to be a 1 x 1 matrix matching the entry above.
aux = factors[0][:, index[0], :]
for core_pos in range(1, len(index)):
    aux = aux @ factors[core_pos][:, index[core_pos], :]
print(aux)
