import k2
import torch
import logging


def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int):
    """For an uneven-sized batch, the total number of frames after padding may
    cause OOM. Hence, for each batch, which is assumed to be sorted in
    descending order by length, we simply drop the last few (shortest) samples
    so that the retained total frames (after padding) do not exceed the given
    ``allowed_max_frames``.

    The batch dict is modified in place and also returned.

    The batch is returned unchanged when either the whole batch already fits
    in the budget, or not even a single padded utterance fits
    (``allowed_max_frames < T``).

    Args:
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it. It must contain ``batch["inputs"]`` (a tensor
        whose ``size()`` unpacks to ``(N, T, _)``) and
        ``batch["supervisions"]``, a dict whose values all have length ``N``
        and which includes a ``"num_frames"`` tensor.
      allowed_max_frames:
        The allowed max number of frames in batch (after padding).

    Returns:
      The (possibly truncated) batch; the same dict object as ``batch``.
    """
    features = batch["inputs"]
    supervisions = batch["supervisions"]

    N, T, _ = features.size()
    # The padded length T must match the longest utterance in the batch.
    assert T == supervisions["num_frames"].max(), (T, supervisions["num_frames"].max())
    # Number of utterances, each padded to T frames, that fit in the budget.
    kept_num_utt = allowed_max_frames // T

    # Nothing to do if everything fits; if not even one utterance fits we
    # cannot filter meaningfully, so the batch is left as-is.
    if kept_num_utt >= N or kept_num_utt == 0:
        return batch

    # Note: we assume the samples in the batch are sorted in descending order
    # of length, so keeping the first `kept_num_utt` drops the shortest ones.
    logging.debug(
        "Filtering uneven-sized batch, original batch size is %d, "
        "retained batch size is %d.",
        N,
        kept_num_utt,
    )
    batch["inputs"] = features[:kept_num_utt]
    for k, v in supervisions.items():
        assert len(v) == N, (len(v), N)
        # Replacing values of existing keys is safe while iterating .items().
        supervisions[k] = v[:kept_num_utt]

    return batch


def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor:
    """Prepend a value to the beginning of each sublist or append a value
    to the end of each sublist.

    Args:
      ragged:
        A ragged tensor with two axes.
      value:
        The value to prepend or append.
      direction:
        It can be either "left" or "right". If it is "left", we
        prepend the value to the beginning of each sublist;
        if it is "right", we append the value to the end of each
        sublist.

    Returns:
      Return a new ragged tensor, whose sublists either start with
      or end with the given value.

    Raises:
      ValueError:
        If ``direction`` is neither "left" nor "right".

    >>> a = k2.RaggedTensor([[1, 3], [5]])
    >>> a
    [ [ 1 3 ] [ 5 ] ]
    >>> concat(a, value=0, direction="left")
    [ [ 0 1 3 ] [ 0 5 ] ]
    >>> concat(a, value=0, direction="right")
    [ [ 1 3 0 ] [ 5 0 ] ]

    """
    assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"

    # Validate before allocating the pad tensor.
    if direction not in ("left", "right"):
        raise ValueError(
            f'Unsupported direction: {direction}. Expect either "left" or "right"'
        )

    # One single-element sublist per sublist of `ragged` (tot_size(0) rows),
    # on the same device and with the same dtype so the two can be concatenated.
    pad_values = torch.full(
        size=(ragged.tot_size(0), 1),
        fill_value=value,
        device=ragged.device,
        dtype=ragged.dtype,
    )
    pad = k2.RaggedTensor(pad_values)

    if direction == "left":
        return k2.ragged.cat([pad, ragged], axis=1)
    return k2.ragged.cat([ragged, pad], axis=1)

