import os
import pickle
import numpy as np
from collections import defaultdict
from sklearn.datasets import fetch_rcv1

# Seed NumPy's global RNG so the two shuffles below are reproducible.
np.random.seed(1024)
rcv1 = fetch_rcv1()

X = rcv1.data
y = rcv1.target
os.makedirs('./rcv1', exist_ok=True)
n = 50000  # number of samples at which the selection loop further down stops
k = 10  # number of label columns kept in y
tot = X.shape[0]

# Apply one random row permutation to both X and y so features and labels
# stay aligned, then densify y (it is expected to expose `.toarray()`).
index = np.arange(tot)
np.random.shuffle(index)
X = X[index, :]
y = y[index, :]
y = y.toarray()
print(n, k)

# Rank the label columns by their column sums, largest first, skip the first
# ten ranks and keep the next k. The slice width is tied to k so the number of
# kept columns always matches the `range(k)` loops used later in the script.
lbl_count = np.sum(y, axis=0)
lbl_index = np.argsort(lbl_count)[::-1][10:10 + k]
# Randomise the column order of the kept labels.
np.random.shuffle(lbl_index)
y = y[:, lbl_index]
lbl_count = np.sum(y, axis=0)
print(lbl_count)

# Greedily walk the rows of y in order and keep a row only if accepting it
# leaves every per-label running total at or below n_max. Stop as soon as
# n rows have been collected.
ratio_max = 1.0 / 5
n_max = int(n * ratio_max)
print(n_max)
select = []
# Explicit int64 counters. Accumulating into plain Python ints via `+=` would
# let the counters take on the dtype of y's elements; under NumPy 2 promotion
# rules a narrow integer dtype (e.g. uint8 -- NOTE(review): confirm y's dtype)
# would then wrap around and the n_max cap would never trigger.
cnt = np.zeros(k, dtype=np.int64)
for i in range(tot):
    yi = y[i, :]
    if np.sum(yi) == 0:
        # Row carries none of the kept labels.
        continue
    # Totals as they would be if this row were accepted.
    updated = cnt + yi[:k]
    if np.any(updated > n_max):
        continue
    select.append(i)
    cnt = updated
    if len(select) == n:
        break

# Index X and y with the same row list so features and labels stay aligned.
y_new = y[select, :]
x_new = X[select, :]
data_info = {'X': x_new, 'y': y_new}
# Context manager guarantees the pickle is flushed and the handle closed even
# if dumping raises; the bare open() used before never closed the file.
with open('./rcv1/data_full_10class.gt', 'wb') as f:
    pickle.dump(data_info, f)

# Report the per-label column sums of the saved subset.
lbl_count = np.sum(y_new, axis=0)
print(lbl_count)

# f = open('./rcv1/data_full_10class.gt', 'rb')
# data = pickle.load(f)
# X, y = data['X'], data['y']
# label_count = np.sum(y, axis=0)
# print(np.min(label_count), np.max(label_count))



