import torch

import numpy as np

class DefaultSampler(torch.utils.data.Sampler):
    """Sampler that yields a fixed collection of indices for several epochs.

    Args:
        inds: Indices to sample from; copied into a NumPy array and kept as
            ``self.inds``.
        shuffle: When true, ``self.inds`` is shuffled in place with NumPy's
            global RNG before each epoch's pass is emitted, so the stored
            order changes after iteration.
        epochs: How many passes over ``inds`` a single iteration produces.
    """

    def __init__(self, inds, shuffle=False, epochs=1):
        self.epochs = epochs
        self.shuffle = shuffle
        self.inds = np.array(inds)

    def _epoch_order(self):
        """Return one epoch's index order as a Python list, reshuffling first if enabled."""
        if self.shuffle:
            np.random.shuffle(self.inds)
        return self.inds.tolist()

    def __iter__(self):
        # Build every epoch eagerly so all shuffling happens when iteration
        # begins, not lazily while the consumer pulls items.
        per_epoch = [self._epoch_order() for _ in range(self.epochs)]
        flat = [idx for chunk in per_epoch for idx in chunk]
        return iter(flat)

    def __len__(self):
        # Total number of indices yielded across all epochs.
        return self.epochs * len(self.inds)
    
    
def load_sampler(stream_inds, config):
    """Construct the sampler over ``stream_inds``.

    ``config`` is accepted but not read here. The sampler is always created
    with ``DefaultSampler``'s default settings: no shuffling and one epoch.
    """
    sampler = DefaultSampler(stream_inds)
    return sampler
