
import torch

# Pick the default compute device: CUDA when PyTorch can see a GPU, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class MinibatchSampler:
    """Hand out contiguous minibatch slices of ``range(n_samples)`` in shuffled order.

    The samples are split into ``ceil(n_samples / batch_size)`` contiguous
    batches; only the last batch may be shorter than ``batch_size``. Batches
    are visited in a random permutation that is redrawn each time every batch
    has been returned exactly once.

    Args:
        n_samples: Total number of samples; must be >= 0.
        batch_size: Maximum number of samples per batch; must be >= 1.
        generator: Optional ``torch.Generator`` used to draw the batch
            permutations (for reproducible order). ``None`` uses torch's
            default global RNG, as before.

    Raises:
        ValueError: If ``n_samples`` is negative or ``batch_size`` is < 1.
    """

    def __init__(self, n_samples, batch_size=1, generator=None):
        if n_samples < 0:
            raise ValueError(f"n_samples must be >= 0, got {n_samples}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.generator = generator
        # Ceiling division: a trailing partial batch still counts as a batch.
        self.n_batches = (n_samples + batch_size - 1) // batch_size
        self.batch_order = self._new_order()
        self.i_batch = 0

    def _new_order(self):
        """Return a fresh random permutation of the batch indices."""
        # Deliberately left on the CPU: the order is only read back with
        # ``.item()``, and reading an element of a CUDA tensor would force a
        # host/device synchronisation on every call to ``get_batch``.
        return torch.randperm(self.n_batches, generator=self.generator)

    def get_batch(self):
        """Return the next batch in the current shuffled order.

        Returns:
            ``(selector, (idx, weight))`` where ``selector`` is a ``slice``
            over the sample range, ``idx`` is the batch's index (int) in
            unshuffled order, and ``weight`` is the fraction of all samples
            the batch contains (``len(batch) / n_samples``), so the weights of
            one full pass sum to 1.

        Raises:
            IndexError: If the sampler was built with ``n_samples == 0``.
        """
        if self.n_batches == 0:
            raise IndexError("cannot draw a batch: sampler has no samples")

        idx = self.batch_order[self.i_batch].item()
        start = idx * self.batch_size
        # The final batch is clipped to the number of samples.
        stop = min(start + self.batch_size, self.n_samples)
        selector = slice(start, stop)

        self.i_batch += 1
        if self.i_batch == self.n_batches:
            # Every batch has been served once: reshuffle for the next pass.
            self.batch_order = self._new_order()
            self.i_batch = 0

        weight = (stop - start) / self.n_samples
        return selector, (idx, weight)