import json
import multiprocessing as mp

import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.spatial import ConvexHull
from tqdm import tqdm

from src.data_representation import Batch


def _sample_ball(count, radius, dimensions):
    """Draw ``count`` points uniformly from a ``dimensions``-D ball of ``radius``.

    Directions are normalised Gaussian samples (uniform on the unit sphere);
    radii are ``U ** (1 / dimensions)`` so that points are uniform in volume
    rather than clustered near the centre.  Uses NumPy's global RNG.
    """
    directions = np.random.normal(size=(count, dimensions))
    directions /= np.linalg.norm(directions, axis=1)[:, np.newaxis]
    radii = np.random.rand(count) ** (1 / dimensions)
    return directions * (radii[:, np.newaxis] * radius)


def highd_ball_ptset(r, num_points=100, dimensions=3):
    """Sample a labelled point set from two concentric balls.

    Half of the points (rounded down) are drawn uniformly from a ball of
    radius ``r``; the rest are drawn uniformly from a ball of radius
    ``r // 2``.  Each point is labelled 1 if it is a vertex of the convex
    hull of the whole set, 0 otherwise.

    Args:
        r: Radius of the outer ball.
        num_points: Total number of points to generate.
        dimensions: Ambient dimension of the points.

    Returns:
        ``np.ndarray`` of shape ``(num_points, dimensions + 1)``: the first
        ``dimensions`` columns are coordinates, the last column is the
        hull-vertex label.  Rows are shuffled with NumPy's global RNG.

    Raises:
        A Qhull error from ``scipy.spatial.ConvexHull`` if the points are too
        few or degenerate to form a full-dimensional hull.
    """
    n_large = num_points // 2
    # The inner ball takes the remainder so odd ``num_points`` still yields
    # exactly ``num_points`` rows (matching the ``labels`` array below).
    n_small = num_points - n_large

    large_points = _sample_ball(n_large, r, dimensions)
    # NOTE(review): ``r // 2`` floors (e.g. r=5 -> 2); kept for compatibility
    # with existing datasets — use ``r / 2`` if an exact half radius was meant.
    # The inner ball draws its own radii so it is independent of the outer one.
    small_points = _sample_ball(n_small, r // 2, dimensions)

    points = np.vstack((large_points, small_points))

    hull = ConvexHull(points)

    labels = np.zeros(len(points), dtype=int)
    labels[hull.vertices] = 1

    # Coordinates followed by the label column, then shuffled row-wise.
    dataset = np.hstack((points, labels[:, np.newaxis]))
    np.random.shuffle(dataset)

    return dataset

def generate_dataset_mp(r, num_points=500, dimensions=3, num_samples=3000, num_workers=None):
    """Generate ``num_samples`` independent point sets in parallel.

    Each sample is the result of ``highd_ball_ptset(r, num_points, dimensions)``
    computed in a ``multiprocessing.Pool``.

    Args:
        r: Outer-ball radius passed to ``highd_ball_ptset``.
        num_points: Points per sample.
        dimensions: Ambient dimension of each point.
        num_samples: Number of point sets to generate.
        num_workers: Pool size; falls back to ``mp.cpu_count()`` when falsy.

    Returns:
        A list of ``num_samples`` arrays, each as returned by
        ``highd_ball_ptset``.
    """
    args = [(r, num_points, dimensions)] * num_samples
    # Reseed NumPy's global RNG in every worker from OS entropy.  With the
    # "fork" start method children otherwise inherit the parent's RNG state,
    # so different workers would emit identical samples.
    with mp.Pool(processes=num_workers or mp.cpu_count(),
                 initializer=np.random.seed) as pool:
        dataset = pool.starmap(highd_ball_ptset, args)
    return dataset

def main():
    """Generate and save train and test 10-D two-ball datasets.

    Both splits use identical generation parameters.  They are written to
    separate files; previously the test split was saved to the training
    path and overwrote it.
    """
    data_dir = '/data/oren/coreset/data'
    params = dict(num_points=500, dimensions=10, num_samples=3000, num_workers=20)

    dataset = generate_dataset_mp(5, **params)
    np.save(f'{data_dir}/10d_uniform_500.npy', dataset)

    test_dataset = generate_dataset_mp(5, **params)
    np.save(f'{data_dir}/10d_uniform_500_test.npy', test_dataset)


if __name__ == "__main__":
    # Run generation only when executed as a script, not on import; this also
    # keeps multiprocessing children started via "spawn" from re-running it.
    main()