import torch
import torch.nn.functional as F


def get_mask_from_lengths(lengths, max_len=None):
    """Build a boolean validity mask from per-sequence lengths.

    Args:
        lengths: 1-D tensor with one length per batch item. It is cast to
            ``torch.long`` before use.
        max_len: Width of the mask. Defaults to the largest value in
            ``lengths``, or 0 when ``lengths`` is empty.

    Returns:
        Bool tensor of shape ``(len(lengths), max_len)`` on ``lengths.device``
        where ``mask[i, j]`` is True iff ``j < lengths[i]``.
    """
    lengths = lengths.to(torch.long)
    if max_len is None:
        # torch.max raises on an empty tensor; an empty batch gets width 0.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0

    # Allocate the index row directly on the target device rather than
    # building it on the CPU and copying it over on every call.
    ids = torch.arange(max_len, device=lengths.device)
    # Broadcasting (1, max_len) against (batch, 1) yields (batch, max_len)
    # without explicit expand calls.
    return ids.unsqueeze(0) < lengths.unsqueeze(1)


def linear_interpolation(features, seq_len):
    """Linearly resample ``features`` along dimension 1 to ``seq_len`` steps.

    Dimensions 1 and 2 are swapped so the resampled axis is last. The result
    goes to ``F.interpolate`` (``mode="linear"``, ``align_corners=True``),
    which resamples the last axis. The swap is then undone.

    NOTE(review): linear mode in ``F.interpolate`` needs a 3-D input, so
    ``features`` is presumably ``(batch, time, channels)``. Confirm against
    callers.

    Args:
        features: 3-D tensor to resample along dimension 1.
        seq_len: Target length of dimension 1.

    Returns:
        Tensor with the same layout as ``features`` and dimension 1 of
        length ``seq_len``.
    """
    channels_first = torch.transpose(features, 1, 2)
    resampled = F.interpolate(
        channels_first, size=seq_len, mode="linear", align_corners=True
    )
    return torch.transpose(resampled, 1, 2)


if __name__ == "__main__":
    # Smoke check: invert the validity mask into a padding mask, where True
    # marks positions at or beyond each sequence's length. Printed instead of
    # stopping in an interactive debugger.
    padding_mask = ~get_mask_from_lengths(torch.tensor([4, 6]))
    print(padding_mask)
