import numpy as np
from sklearn.model_selection import train_test_split
import wandb
# Generate a noisy, nonlinearly-projected 2-D spiral dataset and log it to
# Weights & Biases as the "synthetic_spiral" dataset artifact.
#
# All generation parameters live in the constants below so the artifact
# metadata always matches the data that was actually produced.
SEED = 32

N_POINTS = 12000        # number of samples along the spiral
N_TRAIN = 8000          # samples drawn without replacement for training
RADIUS = 20             # final radius of the spiral
MAX_ANGLE_DEG = 500     # final angle of the spiral, in degrees
DIMENSION = 31          # dimensionality of the projected observations
NOISE_MEAN = 0          # mean of the additive observation noise
TRAIN_NOISE_STD = 0.2   # std of the noise added to the training set
PROJECT_MEAN = 0        # mean of the random 2 -> DIMENSION projection weights
PROJECT_STD = 2         # std of the random 2 -> DIMENSION projection weights
TEST_NOISE_STDS = [0, 0.05, 0.1, 0.2, 0.4]  # one test set per noise level


def _sigmoid(z):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-z))``."""
    return 1 / (1 + np.exp(-1 * z))


def _make_spiral(n_points, radius, max_angle_deg):
    """Return an ``(n_points, 2)`` array of (x, y) points on a spiral.

    Radius and angle both grow linearly from 0 to ``radius`` and
    ``max_angle_deg`` respectively, so the radius is proportional to the angle.
    """
    r = np.linspace(0, radius, n_points)
    theta = np.radians(np.linspace(0, max_angle_deg, n_points))
    return np.asarray([r * np.cos(theta), r * np.sin(theta)]).T


def _observe(low, project, noise_std):
    """Map 2-D points through ``sigmoid(low @ project)`` and add Gaussian noise.

    Draws from numpy's global RNG, so the order of calls matters for
    reproducibility.

    Args:
        low: 2-D array of points, one row per sample.
        project: projection matrix; its column count sets the output width.
        noise_std: standard deviation of the additive noise.

    Returns:
        Array of shape ``(len(low), project.shape[1])``.
    """
    noise = np.random.normal(NOISE_MEAN, noise_std, (len(low), project.shape[1]))
    return _sigmoid(low.dot(project)) + noise


with wandb.init(project="neighbor-vae", entity='engellab', job_type="generate_data") as run:
    spiral_true = _make_spiral(N_POINTS, RADIUS, MAX_ANGLE_DEG)

    # Seed numpy's global RNG. Every random draw below (split, projection,
    # train noise, then test noise per level) depends on this exact order.
    np.random.seed(SEED)

    # Random train/test split; sorting keeps both subsets in spiral order.
    train_indices = np.sort(np.random.choice(N_POINTS, size=N_TRAIN, replace=False))
    is_train = np.zeros(N_POINTS, dtype=bool)
    is_train[train_indices] = True
    train_low = spiral_true[is_train]
    test_low = spiral_true[~is_train]
    print(train_low.shape)
    print(test_low.shape)

    project = np.random.normal(PROJECT_MEAN, PROJECT_STD, (2, DIMENSION))
    train = _observe(train_low, project, TRAIN_NOISE_STD)
    # Each training sample is paired with the next one (in sorted order).
    train_X = train[:-1]
    train_Y = train[1:]

    # Same projection, increasing observation noise per test set.
    test_datasets = [_observe(test_low, project, std) for std in TEST_NOISE_STDS]

    raw_data = wandb.Artifact(
        "synthetic_spiral", type="dataset",
        description="synthetic data of a spiral",
        metadata={"radius": RADIUS,
                  "dimension": DIMENSION,
                  "method": "spiral",
                  "nonlinearity": "sigma",
                  "seed": SEED})

    # File names are what downstream loaders read; note that train_X is
    # stored as "train.npy". test_datasets is a list of equally shaped arrays,
    # which np.save converts into a single stacked array (one slice per
    # noise level).
    outputs = {
        'train': train_X,
        'train_Y': train_Y,
        'test_datasets': test_datasets,
        'test_low': test_low,
        'spiral_true': spiral_true,
        'project': project,
    }
    for file_name, array in outputs.items():
        with raw_data.new_file(file_name + ".npy", mode="wb") as fh:
            np.save(fh, array)
    run.log_artifact(raw_data)