
import numpy as np


class RandomSlice(object):
    """Randomly crop a contiguous window along the last axis of ``sample['inputs']['X']``.

    Parameters
    ----------
    n_frames : int
        Window length, or the maximum window length when ``min_frames`` is given.
    min_frames : int, optional
        If given, the window length is drawn uniformly from the inclusive
        range ``[min_frames, n_frames]`` on every call.
    """

    def __init__(self, n_frames, min_frames=None):
        self.n_frames = n_frames
        self.min_frames = min_frames

    def __call__(self, sample):
        """Return a new sample whose ``X`` is a random window of the original.

        The window is taken along the last axis of ``X``; every start offset
        that keeps the window inside ``X`` is equally likely.  Only
        ``inputs['X']`` and ``targets`` are carried over to the result; any
        other keys in ``sample`` or ``sample['inputs']`` are dropped.

        Raises
        ------
        ValueError
            If ``X`` is shorter along its last axis than the chosen window.
        """
        X = sample['inputs']['X']

        if self.min_frames is not None:
            # randint's upper bound is exclusive, hence +1 to include n_frames.
            frames = np.random.randint(self.min_frames, self.n_frames + 1)
        else:
            frames = self.n_frames

        # Random crop (no wrap-around): choose a start so the window fits.
        max_start = X.shape[-1] - frames
        if max_start < 0:
            raise ValueError(
                "input has %d frames along its last axis, fewer than the "
                "requested window of %d" % (X.shape[-1], frames))
        # +1 because randint excludes `high`: this makes the window ending on
        # the final frame reachable and lets an exact-length input through.
        start = np.random.randint(0, max_start + 1)
        X_slice = X[..., start:start + frames]

        return {'inputs': {'X': X_slice},
                'targets': sample['targets']}
