import h5py

from experiments.neural_datasets.shapenet_sdf import ShapeNetSDF

dataset = ShapeNetSDF(root="datasets/ShapeNetSDF", seed=0)

# The first `num_train` samples (in dataset iteration order) are written to the
# train file; the following `num_test` samples are written to the test file.
num_train = 42_470
num_test = 10_000
# Points per sample, taken from sample 0. Every sample is written into a
# fixed-size slice of this length, so all samples must share it.
num_points = dataset[0][0].shape[0]

train_path = "datasets/ShapeNetSDF/shapenet_train.h5"
test_path = "datasets/ShapeNetSDF/shapenet_test.h5"

num_written = 0
# `with` guarantees both files are flushed and closed even if a sample fails
# part-way through the conversion.
with h5py.File(train_path, "w") as train_f, h5py.File(test_path, "w") as test_f:
    # Points and SDF values of all samples are stored flattened back-to-back:
    # sample k of a split occupies rows [k * num_points, (k + 1) * num_points).
    # NOTE: 'f2' is float16, so coordinates and SDF values lose precision here.
    train_coordinates = train_f.create_dataset("points", (num_train * num_points, 3), dtype='f2')
    train_distances = train_f.create_dataset("sdf", (num_train * num_points,), dtype='f2')
    train_indices = train_f.create_dataset("indices", (num_train,), dtype='i4')
    train_labels = train_f.create_dataset("labels", (num_train,), dtype='i4')

    test_coordinates = test_f.create_dataset("points", (num_test * num_points, 3), dtype='f2')
    test_distances = test_f.create_dataset("sdf", (num_test * num_points,), dtype='f2')
    test_indices = test_f.create_dataset("indices", (num_test,), dtype='i4')
    test_labels = test_f.create_dataset("labels", (num_test,), dtype='i4')

    for i, (points, sdf, idx) in enumerate(dataset):
        if i >= num_train + num_test:
            # Fail loudly instead of letting h5py raise an opaque broadcast/index error.
            raise ValueError(
                f"dataset yields more than {num_train + num_test} samples; "
                "increase num_train/num_test"
            )

        # Select the target split and the sample's row within that split.
        if i < num_train:
            coordinates, distances, indices, labels = (
                train_coordinates, train_distances, train_indices, train_labels
            )
            row = i
        else:
            coordinates, distances, indices, labels = (
                test_coordinates, test_distances, test_indices, test_labels
            )
            row = i - num_train

        start, stop = row * num_points, (row + 1) * num_points
        coordinates[start:stop] = points.reshape(-1, 3)
        distances[start:stop] = sdf.reshape(-1)
        indices[row] = idx
        # NOTE(review): the category is looked up via the loop position `i`,
        # not the returned `idx`; the two agree only if the dataset yields
        # samples in `dataset.models` order — confirm.
        labels[row] = dataset.metadata[dataset.models[i]["category"]]["idx"]
        num_written = i + 1

# Too few samples would leave zero-filled rows at the end of the output; report it.
if num_written < num_train + num_test:
    raise ValueError(
        f"dataset yielded {num_written} samples, expected {num_train + num_test}; "
        "the output files contain zero-filled rows"
    )
