import wandb
import numpy as np
import os
from generate_data_utils import *
import ssm
import pickle

# Number of hidden states to model (REM, SWS, Wake).
N_STATES = 3


def _estimate_cluster_stats(data, labels, n_states=N_STATES):
    """Estimate per-cluster feature means and standard deviations.

    Args:
        data: array whose first axis indexes samples; samples past
            ``len(labels)`` are ignored.
        labels: 1-D array of per-sample cluster labels aligned with ``data``.
        n_states: number of clusters to estimate; the ``n_states`` smallest
            distinct labels are used, in sorted order.

    Returns:
        ``(mus, sigmas)``, each of shape ``(n_states,) + data.shape[1:]``;
        row ``i`` holds the mean / std of the samples carrying the i-th
        smallest label.

    Raises:
        ValueError: if ``data`` has fewer samples than ``labels``, or if
            ``labels`` has fewer than ``n_states`` distinct values.
    """
    if len(data) < len(labels):
        raise ValueError(
            f"data has {len(data)} samples but labels has {len(labels)}"
        )
    data = data[: len(labels)]
    # Unique over ALL labels (not just the first few samples); np.unique sorts,
    # so state i always corresponds to the i-th smallest label.
    clusters = np.unique(labels)[:n_states]
    if len(clusters) < n_states:
        raise ValueError(
            f"expected at least {n_states} distinct labels, found {len(clusters)}"
        )
    mus = np.stack([data[labels == c].mean(axis=0) for c in clusters])
    sigmas = np.stack([data[labels == c].std(axis=0) for c in clusters])
    return mus, sigmas


with wandb.init(project="neighbor_vae", job_type="generate_data") as run:
    artifact = run.use_artifact("juliahwang/lfp_VAE/probe12_subject0_test0:v0", type='dataset')
    artifact_dir = artifact.download()
    # Only the size of the real training set is used: the synthetic train set matches it.
    train_size = np.load(os.path.join(artifact_dir, 'train.npy')).shape[0]
    val_hypno = np.load(os.path.join(artifact_dir, 'hypno.npy'))[0]
    val = np.load(os.path.join(artifact_dir, 'test.npy'))
    val_size = val.shape[0]

    # Estimate mu, sigma for REM, SWS, Wake from the labelled validation data.
    mus, sigmas = _estimate_cluster_stats(val, val_hypno)
    n_dims = mus.shape[1]

    # NOTE(review): pickle.load can execute arbitrary code; hmm_3.p must come from a trusted source.
    with open("hmm_3.p", 'rb') as f:
        hmm_low = pickle.load(f)
    # Reuse the previously fitted transition model. State i's emissions come from the i-th
    # smallest hypnogram label -- TODO confirm this matches hmm_low's state ordering.
    hmm = ssm.HMM(N_STATES, n_dims, observations="gaussian", transitions=hmm_low.transitions)
    # ssm's Gaussian observations are parameterized by (mus, sqrt_Sigmas) via `params`;
    # assigning a plain `sigmas` attribute has no effect. A diagonal sqrt-covariance of
    # per-feature stds gives independent features with the estimated spread.
    hmm.observations.params = (mus, sigmas[:, :, None] * np.eye(n_dims))

    train_hypno, train_samples = hmm.sample(train_size)
    # One-step-ahead pairs: train_Y[t] is the sample that follows train_X[t].
    # NOTE(review): train_hypno keeps all train_size labels, one more than train_X -- confirm
    # downstream alignment.
    train_X = train_samples[:-1]
    train_Y = train_samples[1:]
    test_hypno, test = hmm.sample(val_size)

    raw_data = wandb.Artifact(
        "synthetic_hmm", type="dataset",
        description="synthetic data modeled from visual probe mouse 0 sampled from an HMM",
        metadata={"clusters": N_STATES,
                  "dimensions": n_dims,
                  "method": "HMM"})

    # File names are a downstream contract; note 'train' holds the inputs (train_X).
    datasets = (train_X, train_Y, train_hypno, test, test_hypno)
    names = ['train', 'train_Y', 'train_hypno', 'test', 'test_hypno']
    for array, file_name in zip(datasets, names):
        with raw_data.new_file(file_name + ".npy", mode="wb") as file:
            np.save(file, array)
    run.log_artifact(raw_data)

