import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import ConvexHull
import json
import torch
from src.data_representation import Batch
from mpl_toolkits.mplot3d.art3d import Poly3DCollection


def highd_ball_ptset(num_points=100, dimensions=3, radius=5.0, rng=None):
    """Sample points uniformly inside a ball and label the convex-hull vertices.

    Directions are drawn from an isotropic Gaussian and normalised onto the
    unit sphere; radii are drawn as ``U ** (1 / dimensions)`` so that the
    points are uniform in volume rather than clustered near the centre. The
    result is then scaled by ``radius``.

    Args:
        num_points: Number of points to sample.
        dimensions: Dimensionality of the ambient space.
        radius: Radius of the ball the points are scaled into. The default
            of 5.0 reproduces the previously hard-coded scale factor.
        rng: Optional ``numpy.random.Generator`` to draw from. When ``None``
            the global ``np.random`` state is used, so ``np.random.seed``
            keeps controlling the output exactly as before.

    Returns:
        A float array of shape ``(num_points, dimensions + 1)``. The first
        ``dimensions`` columns hold the coordinates; the last column is
        ``1.0`` for points that are vertices of the convex hull and ``0.0``
        otherwise.

    Raises:
        Whatever ``scipy.spatial.ConvexHull`` raises when the hull cannot be
        built (e.g. too few points for the requested dimension).
    """
    # Draw the Gaussian directions first and the radial uniforms second so
    # the consumption order of the random stream is stable.
    if rng is None:
        points = np.random.normal(size=(num_points, dimensions))
        unit_radii = np.random.rand(num_points)
    else:
        points = rng.normal(size=(num_points, dimensions))
        unit_radii = rng.random(num_points)

    # Project every sample onto the unit sphere.
    points /= np.linalg.norm(points, axis=1, keepdims=True)

    # The d-th root of a uniform variate gives a radius distribution that is
    # uniform in volume inside the unit ball.
    radii = unit_radii ** (1 / dimensions)
    points *= radii[:, np.newaxis]

    # Grow the unit ball to the requested radius.
    points *= radius

    # Compute the convex hull of the point set.
    hull = ConvexHull(points)

    # Label: 1 if the point is a hull vertex, 0 otherwise.
    labels = np.zeros(num_points, dtype=int)
    labels[hull.vertices] = 1

    # Append the labels as a trailing column (they are upcast to float here).
    dataset = np.hstack((points, labels[:, np.newaxis]))

    return dataset

# Driver: build 500 independent samples, each holding 500 labelled points in
# 10 dimensions, and write them to disk as a single .npy file.
print('creating dataset')
dataset = [
    highd_ball_ptset(dimensions=10, num_points=500)
    for _ in range(500)
]
# Every sample has the same (500, 11) shape, so np.save stores the list as
# one stacked array.
np.save('/data/oren/coreset/data/10d_ball_500.npy', dataset)