import os

import numpy as np

category_list = ["charges", "vel", "loc", "edges"]
n = 500  # change this value to select a different number of samples

# Directory holding the full training arrays, and the directory subsets are written to.
SOURCE_DIR = "/atlas2/u/wanjiazh/CGeoDM/datasets/datagen"
SUBSET_DIR = "/atlas2/u/wanjiazh/CGeoDM/datasets/partial_data"


def _source_path(category, src_dir=SOURCE_DIR):
    """Return the path of the full training array for *category*."""
    return os.path.join(src_dir, f"{category}_train_charged5_initvel1.npy")


def _subset_path(category, n_samples, dst_dir=SUBSET_DIR):
    """Return the path the *n_samples*-row subset of *category* is saved to."""
    return os.path.join(
        dst_dir, f"{category}_train{n_samples}_charged5_initvel1.npy"
    )


def subsample_datasets(
    categories, n_samples, *, src_dir=SOURCE_DIR, dst_dir=SUBSET_DIR, seed=None
):
    """Save a random *n_samples*-row subset of every category's array.

    The SAME row indices (along axis 0) are used for every category, so the
    saved subsets stay aligned with each other (row i of each subset comes
    from the same source row).

    Args:
        categories: Iterable of category names (e.g. "charges", "vel").
        n_samples: Number of rows to draw, without replacement.
        src_dir: Directory containing ``{category}_train_charged5_initvel1.npy``.
        dst_dir: Directory to write ``{category}_train{n_samples}_charged5_initvel1.npy``
            into; created if missing.
        seed: Optional seed for ``np.random.default_rng`` to make the
            selection reproducible. ``None`` draws fresh entropy.

    Returns:
        Dict mapping each category to the path its subset was saved to.

    Raises:
        ValueError: if the categories have different numbers of rows, or
            *n_samples* exceeds that number.
    """
    categories = list(categories)
    if not categories:
        return {}

    # First pass: mmap_mode="r" reads only the .npy header, so checking shapes
    # is cheap. Validate everything before any file is written.
    num_rows = {}
    for category in categories:
        data = np.load(_source_path(category, src_dir), mmap_mode="r")
        print(f"Shape of {category} data:", data.shape)
        num_rows[category] = data.shape[0]

    distinct = set(num_rows.values())
    if len(distinct) != 1:
        raise ValueError(
            f"categories have different numbers of samples: {num_rows}"
        )
    (total,) = distinct
    if n_samples > total:
        raise ValueError(
            f"cannot draw {n_samples} samples without replacement from {total}"
        )

    # Draw the indices ONCE so every category's subset refers to the same samples.
    rng = np.random.default_rng(seed)
    indices = rng.choice(total, n_samples, replace=False)

    os.makedirs(dst_dir, exist_ok=True)
    saved = {}
    for category in categories:
        data = np.load(_source_path(category, src_dir))
        subset = data[indices]
        subset_data_path = _subset_path(category, n_samples, dst_dir)
        np.save(subset_data_path, subset)
        print(f"Subset saved to {subset_data_path} with shape:", subset.shape)
        saved[category] = subset_data_path
    return saved


if __name__ == "__main__":
    subsample_datasets(category_list, n)