import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import ConvexHull
import miniball
from geotorch.sphere import uniform_init_sphere_ as unif_sphere
from tqdm import tqdm

import json
import h5py
import torch
from src.data_representation import Batch
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import itertools



def subsample_modelnet(n, data, seed=42):
    """Keep the same random subset of ``n`` points from every point set.

    The point indices are drawn once and applied to every point set, so all
    returned sets use identical point indices. A dedicated ``RandomState``
    seeded with ``seed`` is used so the global NumPy RNG is left untouched.
    With the default seed, the selected indices match the previous
    ``np.random.seed(42); np.random.choice(...)`` behaviour.

    Args:
        n: Number of points to keep per set; must not exceed ``data.shape[1]``.
        data: Array indexed as ``data[set, point, ...]`` (at least 2-D).
        seed: Seed for the index draw (default 42, the previous fixed seed).

    Returns:
        ``data[:, idx]`` for ``n`` distinct point indices ``idx``.

    Raises:
        ValueError: If ``n`` is larger than ``data.shape[1]`` (raised by
            ``choice`` when sampling without replacement).
    """
    # Local generator: reseeding the global RNG here would clobber the
    # caller's random state as a hidden side effect.
    rng = np.random.RandomState(seed)
    col_indices = rng.choice(data.shape[1], size=n, replace=False)
    return data[:, col_indices]

def format_meb_points(ptset):
    """Pack a point set and its minimum enclosing ball into a JSON-ready dict.

    Args:
        ptset: Array of points, passed straight to
            ``miniball.get_bounding_ball``.

    Returns:
        dict with keys ``"pointset"`` (``ptset.tolist()``), ``"meb_center"``
        (the ball centre converted with ``.tolist()``) and ``"meb_radius"``
        (square root of the squared radius reported by miniball).
    """
    ball_center, squared_r = miniball.get_bounding_ball(ptset)
    ball_radius = np.sqrt(squared_r)
    return dict(
        pointset=ptset.tolist(),
        meb_center=ball_center.tolist(),
        meb_radius=ball_radius,
    )

    

def main(
    h5_path='/data/sam/modelnet/data/modelnet40_ply_hdf5_2048/ply_data_train0.h5',
    output_dir='/data/oren/coreset/data',
):
    """Build scaled, shifted, subsampled ModelNet MEB datasets as JSON files.

    The pipeline runs in four steps:

    1. Read every point set from the ``data`` dataset in ``h5_path``.
    2. Seed NumPy's global RNG with 42. Give each point set a random uniform
       scale in [1, 5) and a random per-axis shift in [1, 5).
    3. For each size in 400, 600, ..., 1800, subsample that many points from
       every set and compute the minimum enclosing ball of each subsampled set.
    4. Write the list of entries to
       ``{output_dir}/scaled_subsampled_{size}_modelnet_meb.json``.

    Args:
        h5_path: HDF5 file containing a ``data`` dataset of point sets.
        output_dir: Directory the JSON files are written into (no trailing
            slash).
    """
    # Context manager so the HDF5 handle is closed once the array is in memory.
    with h5py.File(h5_path, 'r') as h5:
        data = h5['data'][:]

    np.random.seed(42)
    # One scale / shift per point set. Point sets lie along axis 0 (the loop
    # below iterates over it). data.shape[1] is the per-set point count and
    # only broadcast when it happened to equal the number of sets.
    num_ptsets = data.shape[0]
    scales = np.random.uniform(1, 5, (num_ptsets, 1, 1))
    shifts = np.random.uniform(1, 5, (num_ptsets, 1, 3))
    transformed_data = data * scales + shifts

    sizes = list(range(400, 2000, 200))
    for num in sizes:
        print(f'Size: {num}')
        # Subsample one size at a time instead of keeping every subsampled
        # copy alive at once. subsample_modelnet uses a fixed seed, so the
        # chosen points do not depend on when it is called.
        dataset = subsample_modelnet(num, transformed_data)
        dataset_meb = [format_meb_points(ptset) for ptset in tqdm(dataset)]

        out_path = f'{output_dir}/scaled_subsampled_{num}_modelnet_meb.json'
        with open(out_path, 'w') as out_file:
            json.dump(dataset_meb, out_file)
    

# Run the dataset-generation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()