import torch
from torch.utils.data import DataLoader, random_split


def get_train_val_test_datasets(dataset, train_ratio, val_ratio):
    """Randomly split ``dataset`` into train, validation and test subsets.

    The train and validation sizes are ``len(dataset) * ratio`` truncated
    towards zero; every remaining sample goes to the test subset, so the
    three subsets always partition the whole dataset.

    Args:
        dataset: A sized dataset accepted by ``torch.utils.data.random_split``.
        train_ratio: Fraction of samples for the train subset, in ``[0, 1]``.
        val_ratio: Fraction of samples for the validation subset, in
            ``[0, 1]``; ``train_ratio + val_ratio`` must not exceed 1.

    Returns:
        A ``(train_set, val_set, test_set)`` tuple of subsets.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1.
    """
    # Explicit raises instead of ``assert``: asserts are stripped under
    # ``python -O``, and a negative size would not be rejected by
    # ``random_split`` (the three lengths still sum to ``len(dataset)``),
    # silently producing overlapping subsets.
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative, got "
            f"train_ratio={train_ratio}, val_ratio={val_ratio}"
        )
    if train_ratio + val_ratio > 1:
        raise ValueError(
            "train_ratio + val_ratio must not exceed 1, got "
            f"{train_ratio} + {val_ratio}"
        )

    total = len(dataset)
    train_size = int(total * train_ratio)
    val_size = int(total * val_ratio)
    # The remainder (including anything lost to truncation) goes to test.
    test_size = total - train_size - val_size

    train_set, val_set, test_set = random_split(
        dataset, [train_size, val_size, test_size]
    )
    return train_set, val_set, test_set


def get_train_val_test_loaders(
    dataset, train_ratio, val_ratio, train_batch_size, val_test_batch_size, num_workers
):
    """Split ``dataset`` and wrap each subset in a ``DataLoader``.

    The split is delegated to ``get_train_val_test_datasets``. Only the
    train loader shuffles; the validation and test loaders keep their order
    and share ``val_test_batch_size``. All three loaders use ``num_workers``.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """
    train_part, val_part, test_part = get_train_val_test_datasets(
        dataset, train_ratio, val_ratio
    )

    # (subset, batch size, shuffle?) in train / val / test order.
    loader_specs = (
        (train_part, train_batch_size, True),
        (val_part, val_test_batch_size, False),
        (test_part, val_test_batch_size, False),
    )
    train_loader, val_loader, test_loader = (
        DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=do_shuffle,
            num_workers=num_workers,
        )
        for subset, batch_size, do_shuffle in loader_specs
    )

    return train_loader, val_loader, test_loader


def get_data_iterator(iterable):
    """Yield items from ``iterable`` forever, restarting it when exhausted.

    Allows training with DataLoaders in a single infinite loop:
    for i, data in enumerate(get_data_iterator(train_loader)):

    Each time ``iterable`` runs out, a fresh iterator is requested from it
    and iteration continues from the start.

    Args:
        iterable: A re-iterable object (e.g. a ``DataLoader`` or a list).

    Yields:
        The items of ``iterable``, cycling endlessly.

    Raises:
        ValueError: If a fresh pass over ``iterable`` produces no items,
            i.e. it is empty or is a one-shot iterator that has already
            been consumed.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            # Restarting would yield nothing again; without this guard the
            # loop would spin forever and the consumer would hang.
            raise ValueError(
                "iterable produced no items; it is empty or an exhausted "
                "one-shot iterator"
            )
