from gluonts.dataset.dataset_slice import slice_datasets
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.stat import calculate_dataset_statistics


def test_dataset_slice():
    """Check ``slice_datasets`` against the ``m4_hourly`` dataset.

    Two scenarios are exercised, both with ``max_num_timeseries=10`` and
    ``max_num_observations=100``:

    1. Default call: the train split is expected to report 10 series with
       a mean target length of 100, and the test split is expected to have
       the same statistics as the unsliced test split.
    2. ``slice_test_target=True``: both the train and the test split are
       expected to report 10 series with a mean target length of 100.
    """
    full = get_dataset("m4_hourly")
    limits = {"max_num_timeseries": 10, "max_num_observations": 100}

    # Scenario 1: slicing with the default handling of the test split.
    default_slice = slice_datasets(full, **limits)

    train_stats = calculate_dataset_statistics(default_slice.train)
    assert train_stats.mean_target_length == 100.0
    assert train_stats.num_time_series == 10

    # The test split must be statistically identical to the original one.
    sliced_test_stats = calculate_dataset_statistics(default_slice.test)
    original_test_stats = calculate_dataset_statistics(full.test)
    assert sliced_test_stats == original_test_stats

    # Scenario 2: request slicing of the test targets as well.
    full_slice = slice_datasets(full, slice_test_target=True, **limits)

    # Train first, then test -- both splits must obey the same limits.
    for split_name in ("train", "test"):
        split_stats = calculate_dataset_statistics(
            getattr(full_slice, split_name)
        )
        assert split_stats.mean_target_length == 100.0
        assert split_stats.num_time_series == 10
