# coding: utf-8

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler,SequentialSampler,BatchSampler
from torch.utils.data.distributed import DistributedSampler

def get_data_loader(dataset, batch_size=1, shuffle=False,
					num_workers=1, random_seed=111, ddp=False, rank=None):
	"""Build a ``DataLoader`` over *dataset* using a mode-specific sampler.

	Three modes are supported, checked in this order:

	* ``ddp=True``: a ``DistributedSampler`` for the given ``rank``, seeded
	  with ``random_seed``; incomplete batches are dropped.
	* ``shuffle=True``: a ``RandomSampler`` without replacement, driven by a
	  ``torch.Generator`` seeded with ``random_seed``; incomplete batches
	  are dropped.
	* otherwise: a ``SequentialSampler``; the final incomplete batch is kept.

	Args:
		dataset: Dataset to iterate. If it has a ``collate_fn`` attribute,
			that is passed to the ``DataLoader``; otherwise ``None`` is
			passed and the loader uses its default collation.
		batch_size: Number of samples per batch.
		shuffle: Whether to visit samples in random order.
		num_workers: Forwarded to ``DataLoader``.
		random_seed: Seed for the shuffling order.
		ddp: Whether to shard the dataset with a ``DistributedSampler``.
		rank: Rank of the current process; required when ``ddp`` is true.

	Returns:
		A ``DataLoader`` that yields batches chosen by a ``BatchSampler``.

	Raises:
		ValueError: If ``ddp`` is true and ``rank`` is ``None``.
	"""
	if ddp:
		# Explicit raise instead of assert: asserts vanish under ``python -O``.
		if rank is None:
			raise ValueError('rank must be set for ddp=True')
		sampler = DistributedSampler(dataset, rank=rank, shuffle=shuffle,
										drop_last=shuffle,
										seed=random_seed)
		# NOTE(review): batches drop the incomplete tail even when
		# shuffle=False, unlike the non-DDP path below -- confirm intended.
		drop_last = True
	elif shuffle:
		# RandomSampler takes no ``seed`` argument; reproducible shuffling is
		# obtained by handing it a seeded generator instead.
		generator = torch.Generator()
		generator.manual_seed(random_seed)
		sampler = RandomSampler(dataset, replacement=False,
								generator=generator)
		drop_last = True
	else:
		sampler = SequentialSampler(dataset)
		drop_last = False
	batch_sampler = BatchSampler(sampler, batch_size, drop_last=drop_last)

	data_loader = DataLoader(dataset, num_workers=num_workers,
					batch_sampler=batch_sampler,
					collate_fn=getattr(dataset, 'collate_fn', None)
					)
	return data_loader