"""
Functions for preprocessing data.
"""

import utils


def align_shapes(data, comp):
    """ Flattens data to two dimensions and tiles the complement to match.

    Both inputs are passed through ``utils.to_torch`` first and the results
    through ``utils.to_np`` last, so torch or np inputs are accepted as far
    as those helpers support them.

    :param data: shape (K x N x D) or (N x D)
    :param comp: complement, i.e. paired data with shape N, e.g. labels or
    indices; may be None, in which case it is handed to ``utils.to_np``
    untouched
    :return: data with shape (K*N) x D and complement with shape (K*N)
    (N x D and N if data had no leading K dimension)
    """
    tensor = utils.to_torch(data)
    paired = comp
    if paired is not None:
        paired = utils.to_torch(paired)
        # The second-to-last axis of data must line up with the complement.
        assert tensor.size(-2) == paired.size(0), 'no alignment'
        n_dims = tensor.dim()
        if n_dims <= 2:
            assert n_dims == 2
        else:
            # data is K x N x D: copy the complement once per leading entry.
            paired = paired.unsqueeze(0).repeat(tensor.size(0), 1)  # K x N
        paired = paired.view(-1)
    # Collapse every leading axis into one and keep the feature axis D.
    flat = tensor.view(-1, tensor.size(-1))
    return utils.to_np(flat), utils.to_np(paired)


def _add_linebreak(string, threshold):
    """ Adds exactly one linebreak to string.
    :param string: text to break
    :param threshold: number of characters after which the text is too long
    :return: string with the last space before threshold replaced by a
    linebreak if it is longer than threshold (or with a linebreak inserted
    at threshold if there is no such space), else string with a linebreak
    appended
    """
    if len(string) <= threshold:
        # Short enough: add linebreak at the end.
        return string + '\n'
    # Last whitespace strictly before threshold, -1 if there is none.
    idx = string.rfind(' ', 0, threshold)
    if idx == -1:
        # Nothing to break on (e.g. one very long word): hard-break at the
        # threshold instead of failing.
        return string[:threshold] + '\n' + string[threshold:]
    # Replace whitespace before threshold with linebreak.
    return string[:idx] + '\n' + string[idx + 1:]


def process_caption(caption, threshold=40):
    """ Split caption across lines if caption is too long
    Two linebreaks are always added: each one either replaces a space
    inside the caption or is appended at the end, so short captions end
    with trailing linebreaks.
    :param caption: caption text
    :param threshold: number of characters that fit one line
    :return: caption containing two additional linebreaks
    """
    caption = _add_linebreak(caption, threshold)
    caption = _add_linebreak(caption, threshold * 2)
    return caption
