import os
import pickle

import numpy as np
import torch
from sklearn.model_selection import train_test_split

# Build a reproducible per-group split over NUM_GROUPS consecutive blocks of
# GROUP_SIZE indices each: TRAIN_PER_GROUP indices per group go to training,
# the rest of each group are held out (collected into `test_idx`).
NUM_GROUPS = 100
GROUP_SIZE = 99
TRAIN_PER_GROUP = 80

# The legacy global RNG and the exact call sequence below (one
# np.random.choice per group, in group order) are kept deliberately so the
# generated split stays identical to previously produced index files.
np.random.seed(42)
pre_idx = list(range(NUM_GROUPS))
idx = []
train_idx = []
test_idx = []

for j in pre_idx:
    # Distinct offsets within this group selected for training.
    idx_i = np.random.choice(GROUP_SIZE, TRAIN_PER_GROUP, replace=False)
    group_start = j * GROUP_SIZE
    idx += range(group_start, group_start + GROUP_SIZE)
    # Cast to built-in int so the pickled lists do not contain numpy scalars.
    train_idx += [group_start + int(jj) for jj in idx_i]

# Set membership keeps this O(len(idx)); a list lookup here was quadratic.
train_set = set(train_idx)
test_idx = [item for item in idx if item not in train_set]
# Split the held-out indices 50/50 into test and validation sets.
# NOTE: train_test_split returns (first_part, second_part); the part sized by
# `train_size` is deliberately used as the *test* set here, the remainder as val.
test_idx, val_idx = train_test_split(test_idx, train_size=0.5, shuffle=True, random_state=42)

print(len(train_idx))
print(len(test_idx))
print(len(val_idx))

# Ensure the output directory exists; open() would otherwise raise
# FileNotFoundError when ./indices has not been created yet.
output_path = './indices/indices_100_80_10_10.pkl'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'wb') as f:
    pickle.dump({'train_idx': train_idx, 'test_idx': test_idx, 'val_idx': val_idx}, f)