
import numpy as np
import matplotlib.pyplot as plt
import pdb

# ------------------------------------------------------------ configuration --
# Per-axis standard deviations shared by both classes (covariances are diagonal).
sigma_1 = 0.1
sigma_2 = 0.1
# Number of samples in each split; each split is divided evenly between classes.
n_train = 4000
n_val = 2000
n_test = 2000

# Class +1 is centred at (0, 1), class -1 at (1, 1): the classes differ only
# along the first feature.
mean_1 = [0, 1]
cov_1 = [[sigma_1**2, 0], [0, sigma_2**2]]

mean_2 = [1, 1]
cov_2 = [[sigma_1**2, 0], [0, sigma_2**2]]


def make_split(n_samples):
    """Draw one labelled split of the two-Gaussian dataset.

    Class +1 samples come from N(mean_1, cov_1) and class -1 samples from
    N(mean_2, cov_2), using the module-level configuration above. The first
    ``n_samples // 2`` rows are class +1. The remaining rows (one extra when
    ``n_samples`` is odd) are class -1, so the result always has exactly
    ``n_samples`` rows.

    Returns:
        (x_1, x_2, x, y_1, y_2, y, data), where ``x_1``/``x_2`` are the
        per-class feature arrays, ``x``/``y`` their concatenations, and
        ``data`` is an (n_samples, 3) array with columns
        [feature_0, feature_1, label].
    """
    n_1 = n_samples // 2
    n_2 = n_samples - n_1  # absorbs the odd sample so the row count is exact
    # Draw class +1 before class -1, matching the original draw order.
    x_1 = np.random.multivariate_normal(mean_1, cov_1, n_1)
    x_2 = np.random.multivariate_normal(mean_2, cov_2, n_2)
    x = np.concatenate((x_1, x_2), axis=0)
    y_1 = np.ones(n_1)
    y_2 = -np.ones(n_2)
    y = np.concatenate((y_1, y_2), axis=0)
    data = np.column_stack((x, y))
    return x_1, x_2, x, y_1, y_2, y, data


# -------------------------------------------------------------------- train --
(x_train_1, x_train_2, x_train,
 y_train_1, y_train_2, y_train, data_train) = make_split(n_train)

# --------------------------------------------------------------- validation --
# Sized by n_val (the original mistakenly used n_test for labels and buffer).
(x_val_1, x_val_2, x_val,
 y_val_1, y_val_2, y_val, data_val) = make_split(n_val)

# ------------------------------------------------------------------ testing --
(x_test_1, x_test_2, x_test,
 y_test_1, y_test_2, y_test, data_test) = make_split(n_test)
# ----------------------------------------------------------------- plotting --
# One scatter plot per split (train, validation, test): class +1 as red
# hexagons, class -1 as blue crosses, with equal axis scaling.
for class_pos, class_neg in (
    (x_train_1, x_train_2),
    (x_val_1, x_val_2),
    (x_test_1, x_test_2),
):
    plt.scatter(class_pos[:, 0], class_pos[:, 1], c='r', marker='H', linewidth=1)
    plt.scatter(class_neg[:, 0], class_neg[:, 1], c='b', marker='x', linewidth=1)
    plt.axis('equal')
    plt.show()

# -------------------------------------------------------------- save to disk --
# The leftover `pdb.set_trace()` that used to sit here dropped every run into
# the debugger before anything was written; it has been removed.
# Each .npy file stores the full split array: columns [feature_0, feature_1, label].
np.save('./data_train.npy', data_train)
np.save('./data_valid.npy', data_val)
np.save('./data_test.npy', data_test)
# Test-split features only (labels dropped), as fixed-width text.
np.savetxt('./data_test.out', data_test[:, 0:2], fmt='%10.5f')


