import abc
import functools
from collections import Counter
import os
from typing import Tuple, Optional, List, Any, Union, Dict, Callable

import torch

import numpy as np

from data.dataset import DatasetMixin
from utils.utils import ceil_div, getitem, generate_intervals


class Transform(abc.ABC):
    """
    Base class for all transforms. A Transform processes one (or several) data points and outputs them. Transforms can
    be chained in a pull-based pipeline: each transform requests the data points it needs from its parent on demand.
    """
    def __init__(self, parent: Optional['Transform']):
        """
        Initialize a transform.

        :param parent: Another transform which is used as the data source for this transform. Can be None in the case
        of a source.
        """
        self.parent = parent

    def get_datapoint(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        """
        Returns a datapoint (in our case this is a sequence) from this transform.

        :param item: Must be 0<=item<len(self)
        :return: A datapoint of the form (inputs, targets), where inputs and targets are tuples of tensors.
        :raises IndexError: If `item` lies outside of [0, len(self)).
        """
        # Computed once: for some subclasses (e.g. windowing) len() iterates over all sequences.
        length = len(self)
        if not (0 <= item < length):
            raise IndexError(f'{type(self).__name__} index {item} out of range (length {length})')

        return self._get_datapoint_impl(item)

    @abc.abstractmethod
    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        """
        This should be implemented by every subclass to produce a datapoint. Bounds checking has already been done
        by `get_datapoint`.

        :param item: Index of the datapoint to fetch
        :return: A datapoint of the form (inputs, targets), where inputs and targets are tuples of tensors.
        """
        raise NotImplementedError

    def __len__(self) -> Optional[int]:
        """
        This should return the number of available sequences after the transformation. By default this is the
        parent's length.

        Note that the built-in `len()` raises a TypeError if this returns None (the default for a transform without
        a parent), so sources are expected to override it.
        """
        return len(self.parent) if self.parent is not None else None

    @property
    def seq_len(self) -> Optional[Union[int, List[int]]]:
        """
        This should return the length of each time series. If the time series have different lengths, the return
        value should be a list that contains the length of each sequence. If all sequences are of equal length,
        this should return an int. By default the parent's value is passed through (None without a parent).
        """
        return self.parent.seq_len if self.parent is not None else None

    @property
    def num_features(self) -> Optional[Union[int, Tuple[int, ...]]]:
        """
        Number of features of each datapoint. This can also be a tuple if the data has more than one feature dimension.
        By default the parent's value is passed through (None without a parent).
        """
        return self.parent.num_features if self.parent is not None else None


class ReconstructionTargetTransform(Transform):
    def __init__(self, parent: Transform, replace_labels: bool = False):
        """
        Uses the inputs of each datapoint as (additional) targets, e.g. for autoencoder-style reconstruction losses.

        :param parent: Another transform which is used as the data source for this transform.
        :param replace_labels: If True, the targets become exactly the inputs. If False, the inputs are appended to the
            tuple of original targets.
        """
        super().__init__(parent)
        self.replace_labels = replace_labels

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        inputs, targets = self.parent.get_datapoint(item)

        # The input tensors themselves act as reconstruction targets.
        new_targets = inputs if self.replace_labels else targets + inputs
        return inputs, new_targets


class OneVsRestTargetTransform(Transform):
    def __init__(self, parent: Transform, normal_class: Optional[Any] = None, anomalous_class: Optional[Any] = None,
                 replace_labels: bool = False):
        """
        Converts multi-class labels into binary anomaly labels (0 = normal, 1 = anomalous).

        :param parent: Another transform which is used as the data source for this transform.
        :param normal_class: Label value treated as normal; it maps to 0 and every other value maps to 1.
        :param anomalous_class: Label value treated as anomalous; it maps to 1 and every other value maps to 0.
            Exactly one of `normal_class` and `anomalous_class` must be given.
        :param replace_labels: If True, the binary labels replace the original ones. If False, they are appended to
            the tuple of original labels.
        :raises ValueError: If neither or both of `normal_class` and `anomalous_class` are given.
        """
        super().__init__(parent)
        self.replace_labels = replace_labels

        neither_given = normal_class is None and anomalous_class is None
        both_given = not (normal_class is None or anomalous_class is None)
        if neither_given:
            raise ValueError('Must set either normal_class or anomalous_class!')
        if both_given:
            raise ValueError('Cannot specify both normal_class and anomalous_class!')

        self.normal_class = normal_class
        self.anomalous_class = anomalous_class

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        inputs, targets = self.parent.get_datapoint(item)

        # Pick the reference label and what a match / non-match becomes.
        if self.normal_class is not None:
            reference, on_match, otherwise = self.normal_class, 0, 1
        elif self.anomalous_class is not None:
            reference, on_match, otherwise = self.anomalous_class, 1, 0
        else:
            reference = None

        if reference is None:
            binary_targets = targets
        else:
            binary_targets = tuple(torch.where(label == reference, on_match, otherwise) for label in targets)

        if not self.replace_labels:
            return inputs, targets + binary_targets
        return inputs, binary_targets


class SubsampleTransform(Transform):
    def __init__(self, parent: Transform, subsampling_factor: int, aggregation: str = 'first'):
        """
        Subsample sequences by a specified factor. `subsampling_factor` consecutive datapoints in a sequence will be
        aggregated into one point using the `aggregation` function.

        :param parent: Another transform which is used as the data source for this transform.
        :param subsampling_factor: This specifies the number of consecutive data points that will be aggregated.
        :param aggregation: The function that should be applied to aggregate a window of data points. Can be either
            'mean', 'last' or 'first'. Any other value falls back to 'first'.
        """
        super(SubsampleTransform, self).__init__(parent)

        self.subsampling_factor = subsampling_factor
        # Every aggregate_fn reduces axis 1 of a (windows, subsampling_factor, ...) tensor.
        # `elif` matters here: with two independent `if`s, 'mean' would be overwritten by the 'first' branch.
        if aggregation == 'mean':
            self.aggregate_fn = functools.partial(torch.mean, dim=1)
        elif aggregation == 'last':
            self.aggregate_fn = functools.partial(getitem, item=(slice(None), -1))
        else:  # 'first'
            self.aggregate_fn = functools.partial(getitem, item=(slice(None), 0))

    def _process_tensor(self, inp: torch.Tensor) -> torch.Tensor:
        """
        Subsample a single tensor along its first (time) axis.

        :param inp: Tensor of shape (T, ...).
        :return: Tensor of shape (ceil(T / subsampling_factor), ...).
        """
        # Reshape to (T//subsampling_factor, subsampling_factor, ...) and aggregate over the 2nd axis. We might need
        # to add padding at the end for this to work.
        inp_shape = inp.shape
        new_t, rest = divmod(inp_shape[0], self.subsampling_factor)
        if rest > 0:
            # Pad with the aggregate of the last (incomplete) window, so that aggregating the padded window yields
            # the same result as aggregating the incomplete window alone.
            pad_value = self.aggregate_fn(inp[new_t * self.subsampling_factor:].unsqueeze(0))
            inp = torch.cat([inp] + (self.subsampling_factor - rest) * [pad_value])
            new_t += 1

        # reshape (rather than view) also works for non-contiguous inputs.
        inp = inp.reshape(new_t, self.subsampling_factor, *inp_shape[1:])
        return self.aggregate_fn(inp)

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        inputs, targets = self.parent.get_datapoint(item)

        inputs = tuple(self._process_tensor(inp) for inp in inputs)
        targets = tuple(self._process_tensor(tar) for tar in targets)

        return inputs, targets

    @property
    def seq_len(self):
        """Sequence length(s) after subsampling: ceil(old_length / subsampling_factor)."""
        old_len = self.parent.seq_len
        if old_len is None:
            return None

        if isinstance(old_len, int):
            return ceil_div(old_len, self.subsampling_factor)

        return [ceil_div(old_l, self.subsampling_factor) for old_l in old_len]


class CacheTransform(Transform):
    def __init__(self, parent: Transform):
        """
        Memoizes the parent's datapoints in an in-memory dict, so that expensive upstream computations run at most
        once per index. Entries are never evicted.

        :param parent: Another transform which is used as the data source for this transform.
        """
        super().__init__(parent)

        # Maps item index -> (inputs, targets) as returned by the parent.
        self.cache = {}

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        try:
            return self.cache[item]
        except KeyError:
            pass

        inputs, targets = self.parent.get_datapoint(item)
        entry = (inputs, targets)
        self.cache[item] = entry
        return entry


class WindowTransform(Transform):
    def __init__(self, parent: Transform, window_size: int, step_size: int = 1, reverse: bool = False):
        """
        This transform produces sliding windows from input sequences. Incomplete windows (that can appear if
        `step_size>1`) will not be returned.

        :param parent: Another transform which is used as the data source for this transform.
        :param window_size: The size of each window.
        :param step_size: The step size at which the sliding window is moved along the sequence.
        :param reverse: If this is True, start the sliding window at the end of a sequence, instead of the start. Note
            that this will not reverse the order of sequences in the dataset and only applies within a single sequence.
        """
        super(WindowTransform, self).__init__(parent)
        self.window_size = window_size
        self.step_size = step_size
        self.reverse = reverse

    def _windows_in_sequence(self, seq_l: int) -> int:
        """Number of complete windows (start positions 0, step, 2*step, ...) that fit into a sequence of length seq_l."""
        return ceil_div(max(seq_l - self.window_size + 1, 0), self.step_size)

    def inverse_transform_index(self, item) -> Tuple[int, int]:
        """
        Map a window index of this transform to a position in the parent.

        :param item: Index of a window, 0 <= item < len(self).
        :return: (index of the parent sequence, start offset of the window within that sequence).
        """
        seq_len = self.parent.seq_len

        ts_index = window_start = 0
        if isinstance(seq_len, int):
            # Every sequence has the same length
            ts_index, window_start = divmod(item, self._windows_in_sequence(seq_len))
            window_start *= self.step_size
        else:
            # Sequences have different lengths: walk through them until the sequence containing `item` is reached
            windows_before = 0
            for i, seq_l in enumerate(seq_len):
                windows_here = self._windows_in_sequence(seq_l)
                if windows_before + windows_here > item:
                    ts_index = i
                    window_start = (item - windows_before) * self.step_size
                    break
                windows_before += windows_here

        if self.reverse:
            # Mirror the start within the selected sequence (seq_len may be a per-sequence list).
            ts_len = seq_len if isinstance(seq_len, int) else seq_len[ts_index]
            window_start = ts_len - window_start - self.window_size

        return ts_index, window_start

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        old_i, start = self.inverse_transform_index(item)
        end = start + self.window_size
        inputs, targets = self.parent.get_datapoint(old_i)

        out_inputs = tuple(inp[start:end] for inp in inputs)
        out_targets = tuple(t[start:end] for t in targets)

        return out_inputs, out_targets

    def __len__(self):
        """Total number of complete windows over all parent sequences."""
        old_n = len(self.parent)
        old_ts = self.parent.seq_len
        if isinstance(old_ts, int):
            return old_n * self._windows_in_sequence(old_ts)

        return sum(self._windows_in_sequence(old_t) for old_t in old_ts)

    @property
    def seq_len(self):
        return self.window_size


class PredictionTargetTransform(WindowTransform):
    def __init__(self, parent: Transform, window_size: int, prediction_horizon: int, replace_labels: bool = False,
                 step_size: int = 1, reverse: bool = False):
        """
        Cuts windows of `window_size + prediction_horizon` points and splits each one into an input part (the first
        `window_size` points) and a prediction target (the last `prediction_horizon` points).

        :param parent: Another transform which is used as the data source for this transform.
        :param window_size: Number of datapoints in each input window.
        :param prediction_horizon: Number of datapoints that should be predicted.
        :param replace_labels: Whether the original labels should be replaced by the prediction target. If False, the
            prediction target will be added to the tuple of original labels (which are cut to the prediction horizon).
        :param step_size: See `WindowTransform`.
        :param reverse: See `WindowTransform`.
        """
        super(PredictionTargetTransform, self).__init__(parent, window_size + prediction_horizon, step_size, reverse)

        self.input_window_size = window_size
        self.prediction_horizon = prediction_horizon
        self.replace_labels = replace_labels

    def _split_index(self, tensor: torch.Tensor) -> int:
        """Index separating the leading part of `tensor` from its last `prediction_horizon` points."""
        # Explicit positions instead of negative slicing: `x[:-0]` would be empty for a horizon of 0.
        return max(tensor.shape[0] - self.prediction_horizon, 0)

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        inputs, targets = super(PredictionTargetTransform, self)._get_datapoint_impl(item)

        new_inputs = tuple(inp[:self._split_index(inp)] for inp in inputs)
        new_targets = tuple(inp[self._split_index(inp):] for inp in inputs)

        if self.replace_labels:
            return new_inputs, new_targets

        # Keep only the original labels that belong to the predicted points.
        targets = tuple(target[self._split_index(target):] for target in targets)
        return new_inputs, targets + new_targets

    @property
    def seq_len(self) -> Union[int, List[int]]:
        return self.input_window_size


class OverlapPredictionTargetTransform(Transform):
    def __init__(self, parent: Transform, offset: int, replace_labels: bool = False):
        """
        Adds the sequence shifted by `offset` as the target: input point t is paired with target point t + offset.

        :param parent: Another transform which is used as the data source for this transform.
        :param offset: Number of steps ahead that should be predicted.
        :param replace_labels: Whether the original labels should be replaced by the shifted sequence. If False, the
            shifted sequence will be added to the tuple of original labels (which are shifted as well).
        """
        super(OverlapPredictionTargetTransform, self).__init__(parent)
        self.offset = offset
        self.replace_labels = replace_labels

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        inputs, targets = self.parent.get_datapoint(item)

        # Explicit end index instead of `inp[:-offset]`, which would be empty for offset == 0.
        new_inputs = tuple(inp[:max(inp.shape[0] - self.offset, 0)] for inp in inputs)
        new_targets = tuple(inp[self.offset:] for inp in inputs)

        if self.replace_labels:
            return new_inputs, new_targets

        targets = tuple(target[self.offset:] for target in targets)
        return new_inputs, targets + new_targets

    @property
    def seq_len(self) -> Optional[Union[int, List[int]]]:
        """Parent length(s) reduced by `offset` (never negative, matching the sliced tensors); None if unknown."""
        parent_seq_len = self.parent.seq_len
        if parent_seq_len is None:
            return None

        if isinstance(parent_seq_len, int):
            return max(parent_seq_len - self.offset, 0)

        return [max(slen - self.offset, 0) for slen in parent_seq_len]


class LimitTransform(Transform):
    def __init__(self, parent: Transform, count: int):
        """
        Exposes at most `count` sequences of the parent (the first ones).

        :param parent: Another transform which is used as the data source for this transform.
        :param count: The max number of sequences that should be returned by this transform.
        """
        super().__init__(parent)
        self.max_count = count

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        # Guard kept for direct calls; get_datapoint already checks against len(self).
        if item < self.max_count:
            return self.parent.get_datapoint(item)
        raise IndexError

    def __len__(self):
        parent_len = len(self.parent)
        if parent_len is None:
            return None
        return min(self.max_count, parent_len)


# --- Artificial anomalies ---


class InjectArtificialAnomaliesTransform(Transform):

    def __init__(self, parent: Transform, n: int, min_length: int = 1, max_length: int = 1):
        """This Transform injects anomalies into the dataset.

        Only the first input tensor and the first target tensor of each datapoint are modified: the anomalous interval
        of the input is overwritten with the anomaly values and the same interval of the target is set to 1.
        Anomaly locations and values are sampled once, at construction time (using NumPy's global random state and
        the randomness of `generate_intervals` / `_inject_anomaly`), so they stay fixed afterwards.

        :param parent: Another transform which is used as the data source for this transform.
        :type parent: Transform
        :param n: Number of anomalies to insert.
        :type n: int
        :param min_length: Minimum length of anomalies.
        :type min_length: int
        :param max_length: Maximum length of anomalies.
        :type max_length: int
        """

        super(InjectArtificialAnomaliesTransform, self).__init__(parent)

        self._sample_intervals(n, min_length, max_length)

    def _sample_intervals(self, n: int, min_length: int, max_length: int):
        """Choose the anomalous intervals and precompute their values.

        Fills `self.indices`, which maps a time-series index to a dict with the keys 'intervals' (list of
        (left, right) boundaries) and 'values' (the tensors written into those intervals).
        """
        intervals_per_time_series = self._compute_intervals_per_time_series(n, min_length)

        # Collect the indices of all anomalies in a dictionary for each time series that has anomalies injected
        self.indices = {}

        for time_series, n_intervals in intervals_per_time_series.most_common():
            # Fetch each series once instead of once per interval
            series = self.parent.get_datapoint(time_series)[0][0]

            # Compute the boundaries of intervals that anomalies are to be injected in
            self.indices[time_series] = {
                'intervals': generate_intervals(n_intervals, min_length, max_length, series.shape[0])
            }

            # Compute the anomalies for the intervals. Slices are cloned so that an anomaly function operating in
            # place cannot modify the parent's data.
            self.indices[time_series]['values'] = [
                self._inject_anomaly(series[left:right].clone())
                for left, right in self.indices[time_series]['intervals']
            ]

    def _compute_intervals_per_time_series(self, n: int, min_length: int):
        """Randomly distribute `n` anomalies over the time series of the parent.

        A series of length L can hold at most L // min_length anomalies; series are drawn with probability
        proportional to their remaining capacity.

        :return: Counter mapping time-series index -> number of anomalies assigned to it.
        """
        sizes = [self.parent.get_datapoint(idx)[0][0].shape[0] for idx in range(len(self))]
        capacities = [size // min_length for size in sizes]

        # Check if it is possible to sample n non overlapping windows from the dataset
        assert sum(capacities) >= n, \
            f'Cannot place {n} non-overlapping anomalies of length {min_length} in the dataset'

        candidate_time_series = {idx: capacity for idx, capacity in enumerate(capacities) if capacity}

        datapoint_indices = []

        for _ in range(n):

            candidates = list(candidate_time_series.keys())
            total_size = sum(candidate_time_series.values())
            probabilities = [candidate_time_series[c] / total_size for c in candidates]

            time_series_index = np.random.choice(candidates, p=probabilities)

            # Consume one unit of capacity; drop the series once it is full
            if candidate_time_series[time_series_index] == 1:
                del candidate_time_series[time_series_index]
            else:
                candidate_time_series[time_series_index] -= 1

            datapoint_indices.append(time_series_index)

        return Counter(datapoint_indices)

    def _get_datapoint_impl(self, item: int) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:

        inputs, targets = self.parent.get_datapoint(item)

        anomalies = self.indices.get(item)
        if anomalies is None:
            return inputs, targets

        # Work on copies: writing into the parent's tensors in place would permanently corrupt cached or shared
        # source data (e.g. behind a CacheTransform or a dataset holding its tensors in memory).
        first_input = inputs[0].clone()
        first_target = targets[0].clone()

        for (left, right), value in zip(anomalies['intervals'], anomalies['values']):
            first_input[left:right] = value
            first_target[left:right] = 1

        return (first_input,) + tuple(inputs[1:]), (first_target,) + tuple(targets[1:])

    @abc.abstractmethod
    def _inject_anomaly(self, interval: torch.Tensor) -> torch.Tensor:
        """Injects an anomaly into an interval.

        :param interval: Interval in a time series of the dataset.
        :type interval: torch.Tensor
        :return: Values to write into the interval; expected to have the same length as the input interval.
        :rtype: torch.Tensor
        """
        raise NotImplementedError


class InjectIndependentArtificialAnomaliesTransform(InjectArtificialAnomaliesTransform):

    def __init__(self, parent: Transform, anomaly_fn: Callable, n: int, min_length: int = 1, max_length: int = 1):
        """Injects anomalies computed solely from the values inside each anomalous interval.

        :param parent: Another transform which is used as the data source for this transform.
        :type parent: Transform
        :param anomaly_fn: Receives an interval tensor and returns a tensor of the same size containing the anomaly.
        :type anomaly_fn: Callable
        :param n: Number of anomalies to insert.
        :type n: int
        :param min_length: Minimum length of anomalies.
        :type min_length: int
        :param max_length: Maximum length of anomalies.
        :type max_length: int
        """
        # Must be assigned before the base initializer runs, because it samples the intervals and immediately calls
        # _inject_anomaly, which needs this attribute.
        self.anomaly = anomaly_fn

        super().__init__(parent, n, min_length, max_length)

    def _inject_anomaly(self, interval: torch.Tensor, index: int = -1, left_boundary: int = -1, right_boundary: int = -1) -> torch.Tensor:
        # The extra positional parameters are unused; the anomaly depends on the interval values only.
        anomaly_fn = self.anomaly
        return anomaly_fn(interval)


class InjectWindowsArtificialAnomaliesTransform(InjectArtificialAnomaliesTransform):

    def __init__(self, parent: Transform, mask_fn: Callable, n: int, min_length: int = 1, max_length: int = 1):
        """Transform that injects windows from somewhere else in the dataset as anomalies.

        2*n intervals are sampled: n of them receive anomalies, the other n serve as reference windows whose values
        are copied (through a mask) into the anomalous intervals.

        :param parent: Another transform which is used as the data source for this transform.
        :type parent: Transform
        :param mask_fn: Callable that computes a mask for a reference window. The injected values are
            `interval * mask + reference * (1 - mask)`, i.e. where the mask is 1 the original values are kept.
        :type mask_fn: Callable
        :param n: Number of anomalies to insert.
        :type n: int
        :param min_length: Minimum length of anomalies.
        :type min_length: int
        :param max_length: Maximum length of anomalies.
        :type max_length: int
        """
        # Must be set before the base initializer, which calls _sample_intervals -> _inject_anomaly
        self.mask = mask_fn

        super(InjectWindowsArtificialAnomaliesTransform, self).__init__(parent, n, min_length, max_length)

    def _sample_intervals(self, n: int, min_length: int, max_length: int):

        intervals_per_time_series = self._compute_intervals_per_time_series(2 * n, min_length)

        all_indices = list(intervals_per_time_series.elements())

        np.random.shuffle(all_indices)

        # Collect all the indices of time series in the dataset that anomalies are to be inserted in
        time_series_containing_anomalies = Counter(all_indices[:n])

        # self.reference_time_series is a sorted list of indices of time series in the dataset,
        # where each index is repeated as often as the number of reference intervals it contains.
        # self.reference_indices is a list containing the relative indices of intervals in the time series,
        # i.e. reference interval i is the self.reference_indices[i]-th interval in self.reference_time_series[i].
        # Both lists are built in ascending time-series order, so they stay aligned (this assumes generate_intervals
        # returns exactly n_intervals intervals).
        self.reference_time_series = sorted(all_indices[n:])
        self.reference_indices = []

        # self.intervals contains all intervals for a time series in the dataset
        self.intervals = {}

        self.indices = {}

        # Generate the intervals for each time series
        for time_series in sorted(intervals_per_time_series.keys()):

            n_intervals = intervals_per_time_series[time_series]

            intervals = generate_intervals(n_intervals,
                                           min_length,
                                           max_length,
                                           self.parent.get_datapoint(time_series)[0][0].shape[0])

            # Collect the indices of intervals where anomalies are to be inserted
            if time_series in time_series_containing_anomalies:
                anomaly_indices = sorted(np.random.choice(list(range(len(intervals))),
                                                          size=time_series_containing_anomalies[time_series],
                                                          replace=False))
            else:
                anomaly_indices = []

            self.intervals[time_series] = intervals

            # All remaining intervals of this series become reference intervals
            anomaly_set = set(anomaly_indices)
            self.reference_indices.extend(idx for idx in range(len(intervals)) if idx not in anomaly_set)

            if anomaly_indices:
                self.indices[time_series] = {
                    'intervals': [intervals[idx] for idx in anomaly_indices],
                    'values': []
                }

        for time_series, entry in self.indices.items():

            # Fetch the series once instead of once per interval
            series = self.parent.get_datapoint(time_series)[0][0]

            for idx, (left_boundary, right_boundary) in enumerate(entry['intervals']):

                anomaly = self._inject_anomaly(series[left_boundary:right_boundary])

                # If the reference interval is smaller than the interval that the anomaly is to be inserted in,
                # the shape is different and thus has to be updated
                entry['intervals'][idx] = (left_boundary, left_boundary + anomaly.shape[0])

                entry['values'].append(anomaly)

    def _inject_anomaly(self, interval: torch.Tensor) -> torch.Tensor:
        """Build an anomaly for `interval` from a randomly drawn (and thereby consumed) reference window.

        :param interval: Values of the interval the anomaly will replace.
        :return: Masked combination of `interval` and the reference window; it is shorter than `interval` when the
            reference window is shorter.
        """
        # Draw a random reference interval
        reference_index = int(np.random.choice(len(self.reference_indices)))

        # Get (and remove, so each reference is used at most once) the index of the time series the reference
        # interval is in and the index in the list of intervals of that time series
        reference_time_series = self.reference_time_series.pop(reference_index)
        reference_interval = self.reference_indices.pop(reference_index)

        # Adjust the size of the intervals to match each other
        left_boundary, right_boundary = self.intervals[reference_time_series][reference_interval]
        right_boundary = min(right_boundary, left_boundary + interval.shape[0])

        # Extract the anomaly from the time series that actually contains the reference interval
        # (reference_index is only a position in the reference lists, not a time-series index)
        anomaly = self.parent.get_datapoint(reference_time_series)[0][0][left_boundary:right_boundary]

        mask = self.mask(anomaly)

        # Apply mask and return constructed anomaly
        length = right_boundary - left_boundary
        return torch.mul(interval[:length], mask) + torch.mul(anomaly, 1 - mask)


def full_mask(inputs: torch.Tensor) -> torch.Tensor:
    """Return an all-zero mask with the same shape, dtype and device as `inputs`.

    When used as `mask_fn` of InjectWindowsArtificialAnomaliesTransform, a zero mask means the injected window
    consists entirely of the reference values.
    """
    return inputs.new_zeros(inputs.shape)


# -------------------------------------


class DatasetSource(Transform):
    def __init__(self, dataset: torch.utils.data.Dataset, start: Union[int, List[int]] = 0,
                 end: Union[int, List[int]] = 0, axis: str = 'batch'):
        """
        This acts as a source transform (meaning it has no parent) that simply returns sequences from a given dataset.
        It can be constrained to return only a specific part of the data.

        Index convention (both axes): a negative `start` counts from the end, and an `end <= 0` counts from the end,
        so the default `end=0` means "until the end".

        :param dataset: The dataset from which to take points.
        :param start: Start index for this dataset. In 'time' mode this may be a list with one value per sequence.
        :param end: End index (exclusive) for this dataset. In 'time' mode this may be a list with one value per
            sequence.
        :param axis: Can be either 'batch' or 'time'. In 'batch' mode, this simply returns only the sequences indexed
            from start to end. 'time' mode is used for datasets that contain only one (or a few) long time series.
            Every time series will be cut according to start and end. Any other value falls back to 'batch'.
        """
        super(DatasetSource, self).__init__(None)

        self.dataset = dataset
        self.axis = axis if axis == 'time' else 'batch'

        if self.axis == 'time':
            data_len = dataset.seq_len
            if isinstance(data_len, int):
                data_len = [data_len] * len(dataset)

            if isinstance(start, int):
                start = [start] * len(dataset)

            if isinstance(end, int):
                end = [end] * len(dataset)

            assert all(-l <= s <= l for l, s in zip(data_len, start))
            assert all(-l < e <= l for l, e in zip(data_len, end))

            # Resolve relative indices per sequence with the same convention as 'batch' mode below; without this the
            # default end=0 would yield empty sequences and seq_len would be wrong for negative indices.
            self.start = [l + s if s < 0 else s for l, s in zip(data_len, start)]
            self.end = [l + e if e <= 0 else e for l, e in zip(data_len, end)]
        else:
            data_len = len(dataset)
            assert isinstance(start, int)
            assert isinstance(end, int)
            assert -data_len <= start < data_len
            assert -data_len < end <= data_len

            self.start = data_len + start if start < 0 else start
            self.end = data_len + end if end <= 0 else end

    def _get_datapoint_impl(self, item) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        if self.axis == 'batch':
            index = self.start + item
            return self.dataset[index]

        # Slice in time dimension
        inputs, targets = self.dataset[item]
        inputs = tuple(inp[self.start[item]:self.end[item]] for inp in inputs)
        targets = tuple(target[self.start[item]:self.end[item]] for target in targets)

        return inputs, targets

    def __len__(self):
        return len(self.dataset) if self.axis == 'time' else (self.end - self.start)

    @property
    def seq_len(self):
        """Length(s) of the sequences this source returns."""
        if self.axis == 'time':
            return [(end - start) for start, end in zip(self.start, self.end)]

        seq_len = self.dataset.seq_len
        if seq_len is None or isinstance(seq_len, int):
            return seq_len

        # Per-sequence lengths: only the sequences in [start, end) are exposed by this source.
        return list(seq_len[self.start:self.end])

    @property
    def num_features(self):
        return self.dataset.num_features


def _split_lengths(percent: torch.Tensor, total: int) -> List[int]:
    """Turn normalized split fractions into integer lengths summing to `total`; leftover units go to the first splits."""
    lengths = torch.floor(percent * total).to(torch.int64)
    remainder = total - torch.sum(lengths).item()
    lengths += torch.tensor([1] * remainder + [0] * (len(lengths) - remainder))
    return lengths.tolist()


def make_dataset_split(dataset: torch.utils.data.Dataset, *splits: float, axis: str = 'batch'):
    """
    Lazily create one `DatasetSource` per requested split of `dataset`.

    :param dataset: The dataset, for which the split should be done.
    :param splits: Relative sizes of the splits; they are normalized to sum to 1.
    :param axis: 'batch' splits the set of sequences, 'time' splits every sequence along time (see `DatasetSource`).
        Any other value is treated as 'batch'.
    :return: A generator yielding one `DatasetSource` per split, in the given order.
    """
    axis = 'time' if axis == 'time' else 'batch'

    # Normalize the relative sizes
    percent = torch.tensor(splits, dtype=torch.float64)
    percent /= torch.sum(percent)

    if axis == 'batch':
        bound = 0
        for length in _split_lengths(percent, len(dataset)):
            lower, bound = bound, bound + length
            yield DatasetSource(dataset, lower, bound, axis=axis)
        return

    seq_lengths = dataset.seq_len
    if isinstance(seq_lengths, int):
        seq_lengths = [seq_lengths] * len(dataset)

    # starts[i][j] / ends[i][j]: boundaries of split i within sequence j
    starts = [[] for _ in percent]
    ends = [[] for _ in percent]
    for seq_l in seq_lengths:
        bound = 0
        for split_idx, length in enumerate(_split_lengths(percent, seq_l)):
            starts[split_idx].append(bound)
            bound += length
            ends[split_idx].append(bound)

    for split_starts, split_ends in zip(starts, ends):
        yield DatasetSource(dataset, split_starts, split_ends, axis=axis)


class PipelineDataset(torch.utils.data.Dataset, DatasetMixin):
    def __init__(self, sink_transform: Transform):
        """
        Dataset that can be used with a `torch.utils.data.DataLoader` and executes a pipeline of transforms to retrieve
        its datapoints.

        :param sink_transform: The last transform in the pipeline that should be queried for data points.
        """
        super(PipelineDataset, self).__init__()

        self.sink_transform = sink_transform

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __getitem__(self, item) -> Tuple[Tuple[torch.Tensor, ...], Tuple[torch.Tensor, ...]]:
        return self.sink_transform.get_datapoint(item)

    def __len__(self):
        return len(self.sink_transform)

    @property
    def seq_len(self) -> Union[int, List[int]]:
        return self.sink_transform.seq_len

    @property
    def num_features(self) -> Union[int, Tuple[int, ...]]:
        return self.sink_transform.num_features

    @staticmethod
    def get_default_pipeline() -> Dict[str, Dict[str, Any]]:
        return {}

    @staticmethod
    def __concatenate_and_save(out_inputs: Tuple[List[torch.Tensor], ...],
                               out_targets: Tuple[List[torch.Tensor], ...], file_name: str, batch_dim: int = 0):
        """Stack each list of tensors along `batch_dim` and save (inputs, targets) with `torch.save`."""
        out_inputs = tuple(torch.stack(inp, dim=batch_dim) for inp in out_inputs)
        out_targets = tuple(torch.stack(tar, dim=batch_dim) for tar in out_targets)
        torch.save((out_inputs, out_targets), file_name)

    @staticmethod
    def get_feature_names() -> List[str]:
        raise NotImplementedError

    def save(self, path: str, chunk_size: int = 0, batch_dim: int = 0):
        """
        Save this dataset as it would be returned after all processing by transforms is done.

        Files are named `data_<k>.pth` inside `path`.

        :param path: The folder in which to save the dataset. It is created if it does not exist.
        :param chunk_size: The maximum number of data points that should be saved in one file. If there are more data
            points than this value, multiple files will be created. Set this to 0 to save the entire dataset in one
            file.
        :param batch_dim: All (or `chunk_size`) datapoints will be stacked along this axis in a new tensor that is then
            saved to disk.
        """
        os.makedirs(path, exist_ok=True)

        out_inputs = None
        out_targets = None
        i = -1
        for i, (inputs, targets) in enumerate(self):
            if out_inputs is None:
                out_inputs = tuple([] for _ in inputs)
            for out_list, inp in zip(out_inputs, inputs):
                out_list.append(inp)

            if out_targets is None:
                out_targets = tuple([] for _ in targets)
            for out_list, target in zip(out_targets, targets):
                out_list.append(target)

            if chunk_size > 0 and i % chunk_size == chunk_size - 1:
                self.__concatenate_and_save(out_inputs, out_targets, os.path.join(path, f'data_{i // chunk_size}.pth'),
                                            batch_dim)
                out_inputs = out_targets = None

        # Save the last (partial) chunk. The buffers are None if the dataset is empty or if its length is an exact
        # multiple of chunk_size (everything was already written above); previously this case crashed.
        if out_inputs is not None:
            file_name = 'data_0.pth' if chunk_size == 0 else f'data_{i // chunk_size}.pth'
            self.__concatenate_and_save(out_inputs, out_targets, os.path.join(path, file_name), batch_dim)
