import numpy as np
import secrets
from numpy.random import MT19937, RandomState
from scipy.stats import norm, qmc, chi, gamma
from scipy.special import gamma as gamma_func

# Multiplier (times ``sigma``) for the constant per-coordinate offset that
# _GeneralSphericalSymmetricSampleGenerator uses for "alternative" positive
# samples, which are meant to lie far away from the mechanism's noise.
big_constant = -1000000000
# Currently only referenced by a commented-out sample-size check in gen_samples.
minimum_sample_size = 2**14

class _GeneralSphericalSymmetricSampleGenerator:
    """
    Base class for generating labelled samples from a spherically symmetric
    noise mechanism applied to two neighbouring queries.

    Each noise vector is ``R * V``: ``V`` is a direction on the unit sphere
    (a normalised Gaussian vector) and ``R`` is a radius produced by the
    subclass hook ``gen_radial_noise_from_lds`` from uniform numbers in [0, 1).

    Negative samples (label 0) are ``X0 + noise``.  Positive samples (label 1)
    are ``X1 + noise`` with probability ``exp(-claimed_epsilon)`` and
    ``X1 + alternative_sample_noise`` otherwise.

    X0, X1: neighbouring queries, stored as row vectors of shape (1, d)
    dimensionality: size of the query (d)
    claimed_epsilon: claimed privacy parameter

    probability_of_natural_sample: exp(-claimed_epsilon)
    probability_of_alternative_sample: 1 - probability_of_natural_sample
    alternative_sample_noise: constant offset of shape (1, d) used for the
        alternative positive samples
    """

    def __init__(self, kwargs):
        self.X0 = kwargs["dataset_settings"]["database_0"]
        self.X1 = kwargs["dataset_settings"]["database_1"]

        if not isinstance(self.X0, np.ndarray):
            self.X0 = np.array(self.X0)
        if not isinstance(self.X1, np.ndarray):
            self.X1 = np.array(self.X1)
        self.X0 = np.reshape(self.X0, (1, self.X0.size))
        self.X1 = np.reshape(self.X1, (1, self.X1.size))

        self.one = np.ones((1,))
        self.zero = np.zeros((1,))
        self.claimed_epsilon = kwargs["dataset_settings"]["claimed_epsilon"]
        self.sigma = kwargs["mechanism_settings"]["sigma"]
        self.error_metric = kwargs["error_metric"]

        # Explicit raise rather than ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers.
        if self.X1.size != self.X0.size:
            raise AssertionError(
                f"database_0 and database_1 must have the same size "
                f"({self.X0.size} != {self.X1.size})"
            )
        self.dimensionality = self.X1.size

        self.probability_of_natural_sample = 1 / (np.exp(self.claimed_epsilon))
        self.probability_of_alternative_sample = 1 - self.probability_of_natural_sample
        # Alternative positive samples are shifted by big_constant * sigma in
        # every coordinate; this is intended to put them where the spherical
        # mechanism's noise has negligible probability.
        self.alternative_sample_noise = (
                big_constant * self.sigma * np.ones_like(self.X1)
        )

        self.reset_rng()

    def gen_radial_noise_from_lds(self, U, num_samples):
        """Map uniforms ``U`` in [0, 1) to radii of shape (num_samples, 1).

        Must be implemented by subclasses.
        """
        raise NotImplementedError("Please implement this method in the child class")

    def reset_rng(self):
        """(Re)create every random generator with fresh seeds from ``secrets``.

        rng:     scrambled Sobol sequence of dimension d + 1 (low-dimensional path)
        rng_qmc: 1-D scrambled Sobol sequence for radii (high-dimensional path)
        rng_np:  MT19937 RandomState for Gaussian directions (high-dimensional
                 path) and for choosing natural vs. alternative positive samples
        """
        self.rng_qmc = qmc.Sobol(d=1, scramble=True, seed=secrets.randbits(128))
        self.rng_np = RandomState(MT19937(secrets.randbits(128)))
        self.rng = qmc.Sobol(d=1 + self.dimensionality, scramble=True, seed=secrets.randbits(128))

    def gen_samples(self, num_samples, generate_positive_sample, reset_rng=False):
        """Generate ``num_samples`` labelled samples.

        Args:
            num_samples: number of rows to generate.
            generate_positive_sample: if True, build samples around ``X1``
                (label 1.0, mixed with alternative samples); otherwise build
                samples around ``X0`` (label 0.0).
            reset_rng: if True, reseed all generators first.

        Returns:
            dict with ``'X'`` of shape (num_samples, d) and ``'y'`` of shape
            (num_samples,).
        """
        if reset_rng:
            self.reset_rng()
        # if generate_positive_sample is True:
        #     assert num_samples*self.probability_of_natural_sample >= minimum_sample_size, f"Minimum sample size must be at least {minimum_sample_size/self.probability_of_natural_sample}"

        d = self.dimensionality
        if d < 150:
            # One (d + 1)-dimensional Sobol sequence drives both the direction
            # (first d coordinates, mapped to standard normals) and the radius
            # (last coordinate).
            U = self.rng.random(num_samples)
            G = norm.ppf(U[:, :d])
            radial_u = U[:, d]
        else:
            # NOTE(review): rationale for the 150 cutoff is not documented here.
            # Directions come from pseudo-random Gaussians; only the radius
            # uses a (1-D) Sobol sequence.
            G = self.rng_np.normal(loc=0, scale=1, size=(num_samples, d))
            radial_u = self.rng_qmc.random(num_samples)

        V = G / np.linalg.norm(G, axis=1, keepdims=True)       # (n, d) unit directions
        z = self.gen_radial_noise_from_lds(radial_u, num_samples)  # (n, 1) radii
        noise = z * V                                           # (n, d)

        if generate_positive_sample:
            y = self.one * np.ones(num_samples)
            # Draw from the instance generator (not the global np.random) so
            # reset_rng() controls every source of randomness here.
            # True -> natural sample (probability exp(-epsilon)).
            natural = (
                self.rng_np.uniform(0, 1, num_samples)
                > self.probability_of_alternative_sample
            )
            offset = np.where(natural[:, None], noise, self.alternative_sample_noise)
            return {'X': self.X1 + offset, 'y': y}

        y = self.zero * np.ones(num_samples)
        return {'X': self.X0 + noise, 'y': y}

class SphericalGaussianSampleGenerator(_GeneralSphericalSymmetricSampleGenerator):
    """Spherical mechanism whose radius is ``sigma * chi(df=k)``.

    ``sigma`` is calibrated from ``noise_budget``:
      * ``'l2'``: sigma = sqrt(noise_budget / k), so E[R^2] = noise_budget
      * ``'l1'``: sigma chosen so that E[R] = noise_budget, using
        E[chi_k] = sqrt(2) * Gamma((k + 1) / 2) / Gamma(k / 2)

    NOTE: ``alternative_sample_noise`` is computed by the base class from
    ``mechanism_settings["sigma"]``, before ``sigma`` is recalibrated here.
    """

    def __init__(self, kwargs):
        from scipy.special import gammaln

        super().__init__(kwargs)

        self.k = kwargs["mechanism_settings"]["k"]
        self.noise_budget = kwargs["mechanism_settings"]["noise_budget"]

        if self.error_metric == 'l2':
            self.sigma = np.sqrt(self.noise_budget/self.k)
        elif self.error_metric == 'l1':
            # Gamma(x) overflows to inf for x > ~171.6, so the direct ratio
            # Gamma(k/2) / Gamma((k+1)/2) collapses to 0 and then nan once
            # k >= 343.  Evaluate it in log space instead.
            log_ratio = gammaln(self.k / 2) - gammaln((self.k + 1) / 2)
            self.sigma = self.noise_budget * np.exp(log_ratio) / np.sqrt(2)
        else:
            raise ValueError(f"Error metric {self.error_metric} not supported")

    def gen_radial_noise_from_lds(self, U, num_samples):
        """Map uniforms ``U`` to radii ``sigma * chi_k^{-1}(U)``, shape (num_samples, 1)."""
        return self.sigma * chi.ppf(U, df=self.k).reshape((num_samples, 1))

class SphericalExponentialSampleGenerator(_GeneralSphericalSymmetricSampleGenerator):
    """Spherical mechanism whose radius is Exponential with scale ``scale``.

    ``scale`` is calibrated from ``noise_budget``:
      * ``'l2'``: E[R^2] = 2 * scale**2 = noise_budget
      * ``'l1'``: E[R] = scale = noise_budget (matches the Gamma generator
        with k = 1)
    """

    def __init__(self, kwargs):
        super().__init__(kwargs)
        self.noise_budget = kwargs["mechanism_settings"]["noise_budget"]

        if self.error_metric == 'l2':
            self.scale = 1.0 / np.sqrt(2/self.noise_budget)
        elif self.error_metric == 'l1':
            self.scale = self.noise_budget
        else:
            raise ValueError(f"Error metric {self.error_metric} not supported")

    def gen_radial_noise_from_lds(self, U, num_samples):
        """Map uniforms ``U`` in [0, 1) to Exponential radii, shape (num_samples, 1).

        Uses the inverse CDF ``-scale * log(1 - u)``.  Unlike ``-log(u)``, this
        stays finite on the whole half-open interval [0, 1), including u == 0.
        """
        return -np.log1p(-U).reshape((num_samples, 1)) * self.scale

class SphericalGammaSampleGenerator(_GeneralSphericalSymmetricSampleGenerator):
    """Spherical mechanism whose radius is Gamma(shape=k, scale=theta).

    ``theta`` is calibrated from ``noise_budget``:
      * ``'l2'``: E[R^2] = k * (k + 1) * theta**2 = noise_budget
      * ``'l1'``: E[R] = k * theta = noise_budget
    """

    def __init__(self, kwargs):
        super().__init__(kwargs)
        mechanism = kwargs["mechanism_settings"]
        self.noise_budget = mechanism["noise_budget"]
        self.k = mechanism["k"]

        metric = self.error_metric
        if metric not in ('l2', 'l1'):
            raise ValueError(f"Error metric {self.error_metric} not supported")
        if metric == 'l1':
            self.theta = self.noise_budget/self.k
        else:
            self.theta = np.sqrt(self.noise_budget / (self.k * (self.k + 1)))

    def gen_radial_noise_from_lds(self, U, num_samples):
        """Map uniforms ``U`` to Gamma radii via the inverse CDF, shape (num_samples, 1)."""
        radii = gamma.ppf(U, self.k, scale=self.theta)
        return radii.reshape((num_samples, 1))


class SphericalUniformSampleGenerator(_GeneralSphericalSymmetricSampleGenerator):
    """Placeholder for a spherical mechanism with uniform radial noise.

    Construction uses the inherited initialiser, but sampling is deliberately
    disabled: ``gen_radial_noise_from_lds`` always raises.
    """

    def gen_radial_noise_from_lds(self, U, num_samples):
        """Always raise ``NotImplementedError``; this radial noise is disabled."""
        raise NotImplementedError("This noise is bad, don't use it")