import tensorly as tl


# TODO : add batched_modes as in batched_tensor_dot?
def batched_outer(tensors):
    """Return the batched (sample-wise) generalized outer product of tensors.

    For every sample index ``s`` along the leading (batch) axis, slice ``s`` of
    the result is the outer product of slice ``s`` of each input tensor.

    Parameters
    ----------
    tensors : iterable of tensors
        Tensors of shape ``(n_samples, ...)``; all must share the same leading
        (batch) dimension ``n_samples``. For example, two tensors of shapes
        ``(n_samples, J1, ..., JN)`` and ``(n_samples, K1, ..., KM)``.

    Returns
    -------
    tensor
        The batched outer product. For the two-tensor example above its shape
        is ``(n_samples, J1, ..., JN, K1, ..., KM)``. When a single tensor is
        given, it is returned unchanged.

    Raises
    ------
    ValueError
        If ``tensors`` is empty, or if the tensors do not all have the same
        batch size.
    """
    iterator = iter(tensors)
    try:
        res = next(iterator)
    except StopIteration:
        raise ValueError("batched_outer requires at least one tensor.") from None

    shape_res = tuple(tl.shape(res))
    # Indices start at 1 so error messages refer to each tensor's position
    # in the original input.
    for i, tensor in enumerate(iterator, start=1):
        shape = tuple(tl.shape(tensor))
        n_samples = shape[0]

        if n_samples != shape_res[0]:
            raise ValueError(
                f"Tensor {i} has a batch-size of {n_samples} but those before had a batch-size of {shape_res[0]}, "
                "all tensors should have the same batch-size."
            )

        n_new_modes = len(shape) - 1
        n_res_modes = len(shape_res) - 1
        # Pad `res` with trailing singleton axes for the new tensor's modes, and
        # pad `tensor` with singleton axes (after the batch axis) for the modes
        # accumulated so far. Broadcasting the product then forms the outer
        # product independently for each sample.
        res = tl.reshape(res, shape_res + (1,) * n_new_modes) * tl.reshape(
            tensor, (n_samples,) + (1,) * n_res_modes + shape[1:]
        )
        shape_res = tuple(tl.shape(res))

    return res


def outer(tensors):
    """Return the generalized outer product of a sequence of tensors.

    Parameters
    ----------
    tensors : iterable of tensors
        Tensors of arbitrary shapes. For example, two tensors of shapes
        ``(J1, ..., JN)`` and ``(K1, ..., KM)``.

    Returns
    -------
    tensor
        The outer product. For the two-tensor example above its shape is
        ``(J1, ..., JN, K1, ..., KM)``. When a single tensor is given, it is
        returned unchanged.

    Raises
    ------
    ValueError
        If ``tensors`` is empty.
    """
    iterator = iter(tensors)
    try:
        res = next(iterator)
    except StopIteration:
        raise ValueError("outer requires at least one tensor.") from None

    shape_res = tuple(tl.shape(res))
    for tensor in iterator:
        shape = tuple(tl.shape(tensor))
        # Pad `res` with trailing singleton axes and `tensor` with leading
        # singleton axes so that broadcasting the product yields the outer
        # product, with `res`'s modes first and `tensor`'s modes after.
        res = tl.reshape(res, shape_res + (1,) * len(shape)) * tl.reshape(
            tensor, (1,) * len(shape_res) + shape
        )
        shape_res = tuple(tl.shape(res))

    return res
