import numpy as np
import pickle

from utils.make_data import create_data
import os

# Seed NumPy's global RNG so repeated runs produce the same pickles
# (assumes create_data draws from np.random's global state -- TODO confirm).
np.random.seed(0)

# Directory that receives one pickle file per dataset type.
data_directory = 'data/'
# Synthetic dataset types passed to create_data; one output file per entry.
datatypes = ['orange_skin', 'XOR', 'nonlinear_additive', 'switch']

# Create the output directory up front. Otherwise a missing folder would only
# surface as FileNotFoundError after the (slow) data generation had finished.
os.makedirs(data_directory, exist_ok=True)

for datatype in datatypes:
    print('saving synthetic data of type: ' + datatype)
    # create_data returns a 7-tuple: train/val splits, per-split datatype
    # values, and the input shape (exact contents defined in utils.make_data).
    x_train, y_train, x_val, y_val, datatype_val, datatype_train, input_shape = create_data(datatype, n=int(1e6))
    data_dict = {'x_train': x_train,
                 'y_train': y_train,
                 'x_val': x_val,
                 'y_val': y_val,
                 'datatype_val': datatype_val,
                 'datatype_train': datatype_train,
                 'input_shape': input_shape}
    # The file is named '<datatype>gt.pk' (no separator). It is kept exactly as
    # before because downstream loaders presumably expect this name.
    data_path = os.path.join(data_directory, datatype + 'gt.pk')
    # `with` guarantees the handle is closed even if pickling raises.
    with open(data_path, 'wb') as data_file:
        pickle.dump(data_dict, data_file)
