"""Manipulation of matrices in the TT format"""

import tensorly as tl
from ._batched_tensordot import tensordot


def tt_matrix_to_tensor(tt_matrix):
    """Returns the full tensor whose TT-Matrix decomposition is given by `tt_matrix`

    Re-assembles the cores in `tt_matrix`, which represent a tensor in
    TT-Matrix format, into the corresponding full tensor.

    Parameters
    ----------
    tt_matrix : iterable of 4D-arrays
        TT-Matrix factors (known as cores). The k-th core has shape
        ``(rank_k, left_dim_k, right_dim_k, rank_{k+1})``.
        The boundary ranks (first dim of the first core, last dim of the
        last core) must both be 1.

    Returns
    -------
    output_tensor : ndarray
        Tensor of shape
        ``(left_dim_0, ..., left_dim_{N-1}, right_dim_0, ..., right_dim_{N-1})``
        whose TT-Matrix decomposition was given by `tt_matrix`.

    Raises
    ------
    ValueError
        If `tt_matrix` contains no cores, or if its boundary ranks are not 1.
    """
    # Materialize once: the cores are inspected for their shapes and then
    # contracted, so a one-shot iterable must not be consumed twice.
    cores = list(tt_matrix)
    if not cores:
        raise ValueError("tt_matrix must contain at least one core.")

    # Each core is of shape (rank_left, size_in, size_out, rank_right)
    shapes = [tl.shape(core) for core in cores]
    _, in_shape, out_shape, _ = zip(*shapes)
    ndim = len(in_shape)

    # The final reshape drops the two boundary rank axes, which is only valid
    # when both have size 1; fail early with an explicit message.
    first_rank, last_rank = shapes[0][0], shapes[-1][-1]
    if first_rank != 1 or last_rank != 1:
        raise ValueError(
            "The boundary ranks of a TT-Matrix must be 1, "
            f"got rank_0={first_rank} and rank_N={last_rank}."
        )

    # Intertwine the dims to match the axis order produced by the contraction:
    # full_shape = in_shape[0], out_shape[0], in_shape[1], out_shape[1], ...
    full_shape = tuple(dim for pair in zip(in_shape, out_shape) for dim in pair)
    # Permutation gathering all input dims first, then all output dims.
    order = list(range(0, ndim * 2, 2)) + list(range(1, ndim * 2, 2))

    # Contract the trailing rank axis of the running result with the leading
    # rank axis of the next core. The result has axes
    # (rank_0, in_0, out_0, ..., in_{N-1}, out_{N-1}, rank_N).
    res = cores[0]
    for core in cores[1:]:
        res = tensordot(res, core, ([-1], [0]))

    return tl.transpose(tl.reshape(res, full_shape), order)
