import numpy as np
import pickle

from utils.make_data import create_local_data
# Seed NumPy's global RNG up front (presumably so the sampling done inside
# create_local_data is reproducible -- TODO confirm it draws from np.random).
np.random.seed(0)

# Output directory. open() does not create it, so it must already exist.
data_directory = 'data/'
# Divided by 100 before being passed to create_local_data
# (presumably per-selector percentages -- verify against utils.make_data).
selector_prob = [25, 15, 10, 50]
# The list's str() form is embedded verbatim: 'switch_[25, 15, 10, 50]'.
datatypes = ['switch_' + str(selector_prob)]
for datatype in datatypes:
    print('saving synthetic data of type: ' + datatype)
    (x_train, y_train, x_val, y_val,
     datatype_val, datatype_train,
     selector_label_train, selector_label_val,
     input_shape) = create_local_data(
        selector_prob=np.array(selector_prob)/100, n=int(1e6))
    # Bundle everything returned by create_local_data under stable keys so
    # it can be restored with a single pickle.load.
    data_dict = {'x_train': x_train,
                 'y_train': y_train,
                 'x_val': x_val,
                 'y_val': y_val,
                 'datatype_val': datatype_val,
                 'datatype_train': datatype_train,
                 'selector_label_train': selector_label_train,
                 'selector_label_val': selector_label_val,
                 'input_shape': input_shape}
    # NOTE: the file name has no separator before 'gt.pk'; it is kept exactly
    # as-is so that existing consumers of the file keep finding it.
    # The context manager guarantees the handle is closed even if
    # pickle.dump raises part-way through.
    with open(data_directory + datatype + 'gt.pk', 'wb') as data_file:
        pickle.dump(data_dict, data_file)
