from itertools import islice
from typing import Optional, Iterable

from gluonts.dataset.common import TrainDatasets, DataEntry


def slice_dataset(
    dataset: Iterable["DataEntry"],
    max_num_timeseries: Optional[int] = None,
    max_num_observations: Optional[int] = None,
):
    """
    Restrict a dataset to a subset of its time series and/or observations.

    Parameters
    ----------
    dataset
        Iterable of data entries. Each entry must have a sliceable
        ``"target"``; when ``max_num_observations`` is given, its
        ``"start"`` must expose a ``freq`` attribute supporting
        ``start + start.freq * n``.
    max_num_timeseries
        If given, keep only the first ``max_num_timeseries`` entries.
    max_num_observations
        If given, keep only the last ``max_num_observations`` values of each
        target and shift ``"start"`` forward by the number of dropped values.
        Entries whose target is shorter than this are returned unchanged
        (the same object, not a copy).

    Returns
    -------
    A list of entries if any slicing was requested, otherwise ``dataset``
    itself.

    Raises
    ------
    ValueError
        If ``max_num_observations`` is negative (``islice`` likewise raises
        ``ValueError`` for a negative ``max_num_timeseries``).
    """
    if max_num_observations is not None and max_num_observations < 0:
        raise ValueError(
            "max_num_observations must be non-negative, "
            f"got {max_num_observations}"
        )

    if max_num_timeseries is not None:
        # the conversion to a list is only needed to compute statistics since
        # `compute_dataset_statistics` calls `len(ts_dataset)`
        dataset = list(islice(dataset, max_num_timeseries))

    def slice_entry(entry):
        target_size = len(entry["target"])
        if max_num_observations > target_size:
            # Nothing to trim: hand back the original entry.
            return entry
        # Keep the *last* observations; the dropped prefix has `offset`
        # points, so the start date moves forward by that many periods.
        offset = target_size - max_num_observations
        # Shallow copy so the caller's entry is not mutated.
        new_entry = entry.copy()
        # Slice from `offset` rather than `[-max_num_observations:]`: for
        # max_num_observations == 0, `[-0:]` would keep the whole target
        # while the start date is still shifted.
        new_entry["target"] = new_entry["target"][offset:]
        new_entry["start"] = (
            new_entry["start"] + new_entry["start"].freq * offset
        )
        return new_entry

    if max_num_observations is not None:
        dataset = [slice_entry(x) for x in dataset]

    return dataset


def slice_datasets(
    datasets,
    max_num_timeseries: Optional[int] = None,
    max_num_observations: Optional[int] = None,
    slice_test_target: bool = False,
):
    """
    Apply ``slice_dataset`` to the splits of a ``TrainDatasets`` object.

    Parameters
    ----------
    datasets
        Object with ``metadata``, ``train`` and ``test`` attributes
        (e.g. a ``TrainDatasets``).
    max_num_timeseries
        Forwarded to ``slice_dataset`` for every sliced split.
    max_num_observations
        Forwarded to ``slice_dataset`` for every sliced split.
    slice_test_target
        If True, the test split is sliced the same way as the training
        split; otherwise it is passed through unchanged.

    Returns
    -------
    A new ``TrainDatasets`` with the original metadata and the
    (possibly) sliced train/test splits.
    """
    train = slice_dataset(
        datasets.train, max_num_timeseries, max_num_observations
    )
    # A missing (None) test split is passed through as-is; slicing it would
    # fail when iterating over None.
    if slice_test_target and datasets.test is not None:
        test = slice_dataset(
            datasets.test, max_num_timeseries, max_num_observations
        )
    else:
        test = datasets.test
    return TrainDatasets(metadata=datasets.metadata, train=train, test=test)
