import torch.utils.data


def split_dataset(
    ds: torch.utils.data.Dataset, prop: float, direction: str
) -> torch.utils.data.Subset:
    """
    Split a dataset at a proportional index and return one contiguous side.

    The split point is ``split_idx = int(prop * len(ds))`` (truncated toward
    zero). Calling this with the same ``prop`` and both directions yields two
    complementary, non-overlapping subsets that together cover ``ds``.

    Args:
        ds: The full dataset; must support ``len()`` and integer indexing.
        prop: Proportion in [0, 1] that locates the split point.
        direction: "forward" keeps items ``[0, split_idx)`` (a ``prop``
            fraction); "backward" keeps items ``[split_idx, len(ds))``
            (the remaining ``1 - prop`` fraction).

    Returns:
        A torch.utils.data.Subset of the dataset. If ``ds`` has a
        ``collate_fn`` attribute, it is copied onto the subset.

    Raises:
        ValueError: If ``prop`` is outside [0, 1] (including NaN), or if
            ``direction`` is not "forward" or "backward".
    """
    # Chained comparison is False for NaN, so NaN is rejected here too
    # instead of failing later inside int().
    if not 0 <= prop <= 1:
        raise ValueError("Proportion should be in [0, 1]")
    split_idx = int(prop * len(ds))
    if direction == "forward":
        indices = range(0, split_idx)
    elif direction == "backward":
        indices = range(split_idx, len(ds))
    else:
        raise ValueError("direction must be 'forward' or 'backward'")

    ds_subset = torch.utils.data.Subset(ds, indices)
    # Subset does not forward attribute access to the wrapped dataset, so
    # propagate a custom collate_fn explicitly. A plain Dataset has no
    # collate_fn, so copy it only when present.
    if hasattr(ds, "collate_fn"):
        ds_subset.collate_fn = ds.collate_fn
    return ds_subset
