import numpy as np
import torch


def messages_to_bytes(messages: np.ndarray, trim=True) -> np.ndarray:
    """Convert an array of message buffers into an array of byte strings.

    The last axis of ``messages`` is treated as the raw buffer of a single
    message; every leading axis is preserved in the result.

    Args:
        messages: Array whose last axis holds one message buffer each.
        trim: If true, each message is cut at its first NUL byte. If false,
            the full buffer is kept. NumPy bytes dtypes drop trailing NUL
            bytes on element access either way, so this only makes a
            difference for data that follows an embedded NUL.

    Returns:
        Array of shape ``messages.shape[:-1]`` with a fixed-width bytes
        dtype. The width is 200 bytes, widened to the buffer size when a
        message buffer is longer than that so nothing is silently truncated.
    """
    message_shape = messages.shape[:-1]
    # ascontiguousarray guarantees C order (and at least one dimension), so
    # each message occupies one consecutive run of bytes.
    contiguous = np.ascontiguousarray(messages)
    width = contiguous.shape[-1] * contiguous.itemsize

    # Keep the historical 200-byte item size, but never go below the actual
    # buffer width: a narrower dtype would chop long messages.
    bytemessages = np.zeros(message_shape, dtype=f"|S{max(200, width)}")
    count = bytemessages.size
    if count == 0 or width == 0:
        return bytemessages

    # One row of raw bytes per message.
    raw = contiguous.reshape(count, -1).view(np.uint8)
    if trim:
        # True from the first NUL of each row onwards; blanking those
        # positions is equivalent to slicing the message at its first NUL.
        from_first_nul = np.logical_or.accumulate(raw == 0, axis=1)
        raw = np.where(from_first_nul, np.uint8(0), raw)

    # Write through a byte-level view of the output so embedded NULs are
    # copied verbatim; the remainder of each item stays zero-padded.
    out_bytes = bytemessages.reshape(-1).view(np.uint8).reshape(count, -1)
    out_bytes[:, :width] = raw
    return bytemessages


def message_to_bytes(msg: np.ndarray, trim=True):
    """Serialize a single message buffer to ``bytes``.

    Args:
        msg: Array holding the raw buffer of one message.
        trim: If true, return only the bytes that precede the first NUL
            byte; if false, return the complete buffer.

    Returns:
        The message as a ``bytes`` object.
    """
    raw = msg.tobytes()
    if trim:
        # partition returns the whole input as its head when no NUL byte is
        # present, so untrimmed-but-full buffers pass through unchanged.
        head, _, _ = raw.partition(b"\x00")
        return head
    return raw
