import errno
import os
import random
import pickle
import numpy as np
import yaml
import itertools

# Path to the environment config; being relative, it is resolved against the
# current working directory at the time this module is loaded.
config_file = './../env.yml'

# Parse the YAML document first, then read the two directory settings from it.
with open(config_file, 'r') as stream:
    yamlfile = yaml.safe_load(stream)

root_dir = yamlfile['root_dir']
src_dir = yamlfile['src_dir']


def _split_indices_by_class(labels, num_classes, per_class):
    """Split sample indices into two groups with a per-class quota.

    Walks ``labels`` in order. For each class, the first ``per_class``
    indices encountered go to the first group; every later index of that
    class goes to the second group.

    Args:
        labels: 1-D sequence of integer class labels.
        num_classes: number of distinct classes; labels must lie in
            ``range(num_classes)``.
        per_class: maximum number of samples per class in the first group.

    Returns:
        A ``(first, second)`` pair of flat index lists, each ordered by class
        and, within a class, by original position.

    Raises:
        IndexError: if a label lies outside ``range(num_classes)``.
    """
    first = [[] for _ in range(num_classes)]
    second = [[] for _ in range(num_classes)]
    for idx, label in enumerate(labels):
        # Reject negatives explicitly: plain list indexing would wrap around
        # and silently file the sample under the wrong class.
        if not 0 <= label < num_classes:
            raise IndexError(
                'label {} at index {} is outside [0, {})'.format(
                    label, idx, num_classes))
        bucket = first[label] if len(first[label]) < per_class else second[label]
        bucket.append(idx)
    return (list(itertools.chain.from_iterable(first)),
            list(itertools.chain.from_iterable(second)))


def main(dataset_path=None, num_classes=200, per_class=250):
    """Split the training set into two shuffled, class-balanced parts.

    Loads ``train_data.npy`` and ``train_label.npy`` from the ``partition``
    sub-directory of the dataset, puts the first ``per_class`` samples of
    every class into part one and the remaining samples into part two,
    shuffles each part independently, and saves them next to the inputs as
    ``train_{data,label}_half1.npy`` and ``train_{data,label}_half2.npy``.

    Args:
        dataset_path: dataset directory containing ``partition``. Defaults to
            ``<root_dir>/tinyimagenet/data``.
        num_classes: number of classes present in the labels.
        per_class: number of samples per class assigned to part one.

    Raises:
        IndexError: if a label lies outside ``range(num_classes)``.
    """
    if dataset_path is None:
        dataset_path = os.path.join(root_dir, 'tinyimagenet', 'data')
    partition_dir = os.path.join(dataset_path, 'partition')

    train_data = np.load(os.path.join(partition_dir, 'train_data.npy'))
    train_label = np.load(os.path.join(partition_dir, 'train_label.npy'))

    idx_list, idx_list_2 = _split_indices_by_class(
        train_label, num_classes, per_class)

    # Shuffle part one before part two (and keep them as Python lists) so the
    # sequence of draws from the global NumPy RNG matches earlier runs.
    np.random.shuffle(idx_list)
    np.random.shuffle(idx_list_2)

    # Quick visual sanity check of the shuffled indices.
    print(idx_list[:20])
    print(idx_list_2[:20])

    for suffix, indices in (('half1', idx_list), ('half2', idx_list_2)):
        np.save(
            os.path.join(partition_dir, 'train_data_{}.npy'.format(suffix)),
            train_data[indices])
        np.save(
            os.path.join(partition_dir, 'train_label_{}.npy'.format(suffix)),
            train_label[indices])


# Run the split only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
