import torch
from src.utils import *
from torch.distributions import MultivariateNormal

def uniform(num_samples, center, width, num_dims, mean=None, cov=None):
    """Sample points with ``unif`` and pair them with a constant density.

    The returned density is ``1 / sphere_volume(width, num_dims)``, the
    reciprocal of the sampling region's volume. ``unif`` is presumably a
    uniform sampler over the ball of radius ``width`` around ``center``
    (TODO confirm against ``src.utils``).

    ``mean`` and ``cov`` are ignored. They only keep this signature identical
    to ``IS_with_gaussian`` and ``sobol``.

    Returns:
        ``(samples, IS_pdf)``: the output of ``unif`` and the scalar density.
    """
    points = unif(num_samples, center, width, num_dims)
    density = 1 / sphere_volume(width, num_dims)
    return points, density

def IS_with_gaussian(num_samples, center, width, num_dims, mean=None, cov=None):
    """Draw samples from a uniform/truncated-Gaussian mixture on a ball.

    ``num_samples`` points come from ``unif``, presumably uniform on the ball
    of radius ``width`` around ``center`` (TODO confirm in ``src.utils``).
    Another ``num_samples`` points are drawn from a Gaussian. Of those, only
    the ones strictly inside that same ball are kept.

    The returned ``IS_pdf`` is the density of the combined sample set. Each
    component is weighted by the number of samples it actually contributed
    (balance heuristic). This keeps ``mean(f(x) / IS_pdf)`` unbiased even when
    truncation throws away many Gaussian draws.

    Args:
        num_samples: number of uniform samples, and number of Gaussian draws
            made before truncation.
        center: centre of the ball; also the default Gaussian mean.
        width: radius of the ball.
        num_dims: dimensionality of the samples.
        mean: Gaussian mean; defaults to ``center``.
        cov: Gaussian covariance; defaults to ``width * I``. Note that the
            diagonal entries are ``width`` (a variance), not ``width ** 2``.

    Returns:
        ``(all_samples, IS_pdf)``. ``all_samples`` stacks the uniform samples
        followed by the accepted Gaussian samples. ``IS_pdf`` holds the mixture
        density evaluated at each of them.
    """
    # Uniform component.
    unif_samples = unif(num_samples, center, width, num_dims)

    if mean is None:
        mean = center
    if cov is None:
        cov = torch.diag(torch.ones((num_dims,)) * width)

    # Gaussian component, truncated to the ball by rejection.
    is_dist = MultivariateNormal(mean, cov)
    gauss_draws = is_dist.sample((num_samples,))
    # Distance is measured from ``center`` (not the origin) so the acceptance
    # region is the same ball the uniform and Sobol samplers cover.
    within_sphere = torch.linalg.norm(gauss_draws - center, dim=-1) < width
    is_samples = gauss_draws[within_sphere]

    all_samples = torch.vstack([unif_samples, is_samples])

    n_unif = unif_samples.shape[0]
    n_gauss = is_samples.shape[0]
    n_drawn = gauss_draws.shape[0]

    # Uniform density on the ball.
    unif_pdf = 1 / sphere_volume(width, num_dims)

    # Truncated Gaussian density: p(x) / P(inside ball). P(inside ball) is
    # estimated as the observed acceptance rate n_gauss / n_drawn. With no
    # accepted draws the component contributes no samples, and thus no mass.
    gauss_pdf = torch.exp(is_dist.log_prob(all_samples))
    if n_gauss > 0:
        trunc_pdf = gauss_pdf * (n_drawn / n_gauss)
    else:
        trunc_pdf = torch.zeros_like(gauss_pdf)

    # Mixture density, weighted by how many samples each component supplied.
    IS_pdf = (n_unif * unif_pdf + n_gauss * trunc_pdf) / (n_unif + n_gauss)

    return all_samples, IS_pdf
def sobol(num_samples, center, width, num_dims, mean=None, cov=None):
    """Return scrambled-Sobol points inside the ball of radius ``width`` around ``center``.

    Points are drawn from a scrambled Sobol engine in batches of
    ``num_samples``. Each batch is mapped from the unit cube onto
    ``[center - width, center + width)`` along every axis. Only points whose
    Euclidean distance to ``center`` is below ``width`` are kept, and batches
    are drawn until the output buffer is full.

    The engine is seeded with 0, so identical arguments always give identical
    points. The share of cube points that fall inside the ball shrinks quickly
    as ``num_dims`` grows, so high-dimensional calls may need many batches.

    ``mean`` and ``cov`` are ignored. They only keep this signature identical
    to ``uniform`` and ``IS_with_gaussian``.

    Returns:
        ``(samples, IS_pdf)``. ``samples`` has shape ``(num_samples, num_dims)``
        and ``IS_pdf`` is the constant ``1 / sphere_volume(width, num_dims)``.
    """
    out = torch.zeros((num_samples, num_dims))
    engine = torch.quasirandom.SobolEngine(dimension=num_dims, scramble=True, seed=0)
    filled = 0
    while filled < num_samples:
        unit_points = engine.draw(num_samples)
        # Map [0, 1) per axis onto [center - width, center + width).
        batch = (unit_points - 0.5) * 2 * width + center
        keep = torch.linalg.norm(batch - center, dim=-1) < width
        batch = batch[keep]

        # Copy only as many accepted points as still fit in the buffer.
        take = min(num_samples - filled, len(batch))
        out[filled:filled + take] = batch[:take]
        filled += len(batch)

    return out, 1 / sphere_volume(width, num_dims)