import torch
import numpy as np

def draw_uniform_cube(batch_size, d, precision=torch.float64):
    """Draw ``batch_size`` i.i.d. points uniformly from the unit cube [0, 1)^d.

    Args:
        batch_size: number of points (rows) to draw.
        d: dimension of each point (columns).
        precision: torch dtype of the returned tensor (default float64).

    Returns:
        Tensor of shape (batch_size, d) produced by ``torch.rand``.
    """
    shape = (batch_size, d)
    return torch.rand(shape, dtype=precision)


def example_random_heavy_tail():
    """
    Heavy-tailed example in d = 10.

    mu is the distribution on R^d with density proportional to
    1 / (1 + ||x||^2)^((d+1)/2) (a multivariate Cauchy law). nu is a discrete
    measure whose support points, weights and reference potential g_opt are
    loaded from ``Main_Examples/generated_data_heavy_Tail.npz`` (keys ``Y``,
    ``weights`` and ``g_opt``).

    Returns:
        (d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt), where the
        three loaded arrays are converted to float64 torch tensors.
    """
    d = 10

    def draw_samples_mu(k, d, precision=torch.float32):
        """
        Samples `k` points from the probability distribution in `d` dimensions
        with density proportional to 1 / (1 + ||x||^2)^((d+1)/2).

        This is a multivariate Student-t law with one degree of freedom. It is
        sampled exactly as z / |w>, with z ~ N(0, I_d) and w ~ N(0, 1)
        independent.

        Returns:
            Tensor of shape (k, d) and dtype ``precision`` (default float32).
        """
        # Why not rejection sampling with an Exponential proposal: the target
        # radial tail decays like r^-2, which dominates exp(-r), so no finite
        # envelope constant exists. Accepting with radial_density(r) / const
        # (without dividing by the proposal density) samples
        # radial_density(r) * exp(-r), a light-tailed law. The ratio
        # construction below is exact.
        z = torch.randn((k, d), dtype=precision)
        w = torch.randn((k, 1), dtype=precision)
        return z / w.abs()

    # Load the pre-generated discrete target measure nu and the potential g_opt.
    # The context manager closes the underlying .npz file handle.
    # NOTE(review): if this file was generated with samples from the previous
    # (light-tailed) sampler, it should be regenerated for this mu.
    with np.load("Main_Examples/generated_data_heavy_Tail.npz") as data:
        nu_points = torch.tensor(data["Y"], dtype=torch.float64)
        nu_weights = torch.tensor(data["weights"], dtype=torch.float64)
        g_opt = torch.tensor(data["g_opt"], dtype=torch.float64)

    # NOTE(review): hard-coded; confirm it bounds the cost used with this example.
    max_cost = 1

    return d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt

def example_random_quadra():
    r"""
    Illustrates a non-symmetric problem, where the weights of \nu are not uniform.

    \mu is fixed to be the uniform measure on [0,1]^d with d = 50.
    A g_opt was fixed randomly, and the weights of \nu were calculated as those
    corresponding to the Laguerre cells L(g_opt). The support points ``Y``, the
    ``weights`` and ``g_opt`` are loaded from
    ``Main_Examples/generated_data_quadratic.npz``.
    Here M = 10 and cost = ||x-y||^2.

    Returns:
        (d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt), where the
        three loaded arrays are converted to float64 torch tensors.
    """
    d = 50

    def draw_samples_mu(batch_size, d, precision=torch.float64):
        """Draw `batch_size` points uniformly from [0, 1)^d with dtype `precision`."""
        # `precision` is honoured; the default (float64) keeps default calls unchanged.
        return torch.rand(batch_size, d, dtype=precision)

    # Load the pre-generated data. The context manager closes the .npz file handle.
    with np.load("Main_Examples/generated_data_quadratic.npz") as data:
        nu_points = torch.tensor(data["Y"], dtype=torch.float64)
        nu_weights = torch.tensor(data["weights"], dtype=torch.float64)
        g_opt = torch.tensor(data["g_opt"], dtype=torch.float64)

    # NOTE(review): hard-coded; confirm it bounds the cost used with this example.
    max_cost = 1

    return d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt

def example_random_exponential():
    r"""
    Illustrates a non-symmetric problem, where the weights of \nu are not uniform.

    \mu is fixed to be the uniform measure on [0,1]^d with d = 50.
    A g_opt was fixed randomly, and the weights of \nu were calculated as those
    corresponding to the Laguerre cells L(g_opt). The support points ``Y``, the
    ``weights`` and ``g_opt`` are loaded from
    ``Main_Examples/generated_data_exponential.npz``.
    Here M = 10 and cost = exp(||x-y||^2).

    Returns:
        (d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt), where the
        three loaded arrays are converted to float64 torch tensors.
    """
    d = 50

    def draw_samples_mu(batch_size, d, precision=torch.float64):
        """Draw `batch_size` points uniformly from [0, 1)^d with dtype `precision`."""
        # `precision` is honoured; the default (float64) keeps default calls unchanged.
        return torch.rand(batch_size, d, dtype=precision)

    # Load the pre-generated data. The context manager closes the .npz file handle.
    with np.load("Main_Examples/generated_data_exponential.npz") as data:
        nu_points = torch.tensor(data["Y"], dtype=torch.float64)
        nu_weights = torch.tensor(data["weights"], dtype=torch.float64)
        g_opt = torch.tensor(data["g_opt"], dtype=torch.float64)

    # NOTE(review): hard-coded; confirm it bounds the cost used with this example.
    max_cost = 1

    return d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt

def example_random_power1_5():
    r"""
    Illustrates a non-symmetric problem, where the weights of \nu are not uniform.

    \mu is fixed to be the uniform measure on [0,1]^d with d = 50.
    A g_opt was fixed randomly, and the weights of \nu were calculated as those
    corresponding to the Laguerre cells L(g_opt). The support points ``Y``, the
    ``weights`` and ``g_opt`` are loaded from
    ``Main_Examples/generated_data_power1_5.npz``.
    Here M = 10, and the cost is presumably ||x-y||^{1.5}. This is inferred
    from the function and data-file name; the earlier note
    "exp(||x-y||^2)" duplicated the exponential example. TODO confirm.

    Returns:
        (d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt), where the
        three loaded arrays are converted to float64 torch tensors.
    """
    d = 50

    def draw_samples_mu(batch_size, d, precision=torch.float64):
        """Draw `batch_size` points uniformly from [0, 1)^d with dtype `precision`."""
        # `precision` is honoured; the default (float64) keeps default calls unchanged.
        return torch.rand(batch_size, d, dtype=precision)

    # Load the pre-generated data. The context manager closes the .npz file handle.
    with np.load("Main_Examples/generated_data_power1_5.npz") as data:
        nu_points = torch.tensor(data["Y"], dtype=torch.float64)
        nu_weights = torch.tensor(data["weights"], dtype=torch.float64)
        g_opt = torch.tensor(data["g_opt"], dtype=torch.float64)

    # NOTE(review): hard-coded; confirm it bounds the cost used with this example.
    max_cost = 1

    return d, draw_samples_mu, nu_points, nu_weights, max_cost, g_opt



