import numpy as np

from collections import Counter

# Seed for the sampling permutation, so the 500k subset is reproducible.
SEED = 0

sampled_num = 500000
class_num = 10
sampled_per_class = sampled_num // class_num


def balanced_interleaved_indices(labels, per_class, num_classes=10):
    """Pick the first ``per_class`` positions of each class and interleave them.

    The result is ordered class 0, 1, ..., num_classes-1, 0, 1, ... so that
    ``labels[result]`` cycles through the classes.

    Args:
        labels: 1-D array of integer class labels.
        per_class: number of positions to take for every class.
        num_classes: classes are assumed to be ``0 .. num_classes-1``.

    Returns:
        1-D integer array of length ``per_class * num_classes`` with positions
        into ``labels``.

    Raises:
        ValueError: if some class has fewer than ``per_class`` occurrences.
    """
    labels = np.asarray(labels)
    picks = []
    for cls in range(num_classes):
        positions = np.flatnonzero(labels == cls)
        if positions.size < per_class:
            raise ValueError(
                f"class {cls} has only {positions.size} samples, "
                f"need {per_class}"
            )
        picks.append(positions[:per_class])
    # Column c holds class c's picks; row-major flattening interleaves classes.
    return np.stack(picks, axis=1).reshape(-1)


if __name__ == "__main__":
    # np.load ignores mmap_mode for .npz archives: data["X"] reads the whole
    # array into memory, so avoid making further full-size copies below.
    with np.load("../data/cifar5m/cifar5m_part0.npz", mmap_mode='r') as data:
        x = data["X"]
        y = data["Y"]

    rng = np.random.default_rng(SEED)
    sampled_id = rng.permutation(y.shape[0])
    sampled_y = y[sampled_id]
    print(Counter(sampled_y))

    # Positions within the permuted order, interleaved by class.
    id_sorted = balanced_interleaved_indices(sampled_y, sampled_per_class, class_num)

    # Map back to rows of the original arrays so X is gathered exactly once
    # (only the selected rows) instead of permuting the full array first.
    source_id = sampled_id[id_sorted]
    x_sorted = x[source_id]
    y_sorted = y[source_id]

    np.save("../data/cifar5m/cifar5m_sampled_x_500k.npy", x_sorted)
    np.save("../data/cifar5m/cifar5m_sampled_y_500k.npy", y_sorted)
