import torch
import numpy as np
from scipy.stats.qmc import Sobol


class QmcWrapper:
	"""
	Wrapper for randomized quasi-Monte Carlo (RQMC) sampling

	Produces batches of uniform samples, either from scrambled Sobol sequences (RQMC)
	or from plain pseudo-random Monte Carlo, and converts uniform samples to
	standard normal samples.
	"""
	def __init__(self, dim, device):
		"""
		:param dim: dimension of each sample point
		:param device: torch device on which the returned samples are placed
		"""
		self.dim = dim
		self.device = device
		# Sobol points are generated by SciPy as NumPy arrays, so they are staged on the CPU first
		self.cpu_device = torch.device('cpu')

	def sample(self, gen_num, batch_size=1, enable_qmc=True):
		"""
		Sample uniform points, from the Sobol sequence (RQMC) or from plain MC
		:param gen_num: number of samples to generate; must be a power of 2 when enable_qmc is True
		:param batch_size: number of batches to generate parallel
		:param enable_qmc: whether to use RQMC sampling (otherwise plain MC via torch.rand)
		:return: uniform samples in [0, 1]^dim, shape:(batch_size, gen_num, dim)
		:raises AssertionError: if enable_qmc is True and gen_num is not a power of 2
		"""
		# use MC sampling: no power-of-2 requirement, so no Sobol order is computed here
		if not enable_qmc:
			return torch.rand(batch_size, gen_num, self.dim, device=self.device)

		# use RQMC sampling
		# exact integer arithmetic for the Sobol order (gen_num == 2 ** order)
		order = int(gen_num).bit_length() - 1
		if gen_num < 1 or 2 ** order != gen_num:
			# raised explicitly (not via `assert`) so the check survives `python -O`
			# while keeping the exception type and message unchanged
			raise AssertionError("RQMC sample size should be power of 2")
		point_count = int(gen_num)

		unif_sample = torch.zeros(batch_size, point_count, self.dim, device=self.cpu_device)
		for index in range(batch_size):
			# a fresh generator per batch entry, so every batch gets its own Sobol point set
			generator = Sobol(d=self.dim)
			# SciPy returns float64; the assignment casts to the dtype of unif_sample
			unif_sample[index] = torch.from_numpy(generator.random_base2(m=order))

		# return the uniform samples on the target device
		return unif_sample.to(self.device)

	def normal_transform(self, unif_sample):
		"""
		Transform uniform samples to standard normal samples
		:param unif_sample: uniform samples in [0, 1]^dim, shape:(batch_size, gen_num, dim)
		:return: standard normal samples, shape:(batch_size, gen_num, dim); the result is
			finite even for inputs equal to 0 or 1, because the argument of erfinv is
			clamped just inside (-1, 1)
		"""
		erf_arg = 2 * unif_sample - 1
		if erf_arg.is_floating_point():
			# erfinv(-1) / erfinv(1) are -inf / +inf; keep the argument strictly inside (-1, 1)
			bound = 1 - torch.finfo(erf_arg.dtype).eps
			erf_arg = erf_arg.clamp(min=-bound, max=bound)
		# inverse normal CDF: Phi^{-1}(u) = sqrt(2) * erfinv(2u - 1)
		return torch.erfinv(erf_arg) * np.sqrt(2)

