import numpy as np
import pandas as pd
from scipy.stats import skew, kurtosis
import os
from sklearn import metrics
import wandb

def get_skew_and_kurtosis(encodings, hypno):
    """Summarize per-label skewness and kurtosis of the encodings.

    For every distinct value in ``hypno`` the rows of ``encodings`` carrying
    that label are selected, and the skewness and kurtosis of each column are
    computed with ``scipy.stats.skew`` / ``scipy.stats.kurtosis`` (SciPy
    defaults, i.e. the Fisher/excess definition for kurtosis).

    Args:
        encodings: 2-D array-like; the first axis is matched element-wise
            against ``hypno``. Nested lists are accepted.
        hypno: 1-D array-like of labels, one per row of ``encodings``.

    Returns:
        A tuple ``(mean_abs_skew, mean_kurtosis)``: the mean over all labels
        and columns of the absolute skewness, and the mean over all labels
        and columns of the kurtosis.
    """
    # Coerce both inputs so plain Python sequences support boolean-mask
    # indexing; this is a no-op for inputs that are already ndarrays.
    encodings = np.asarray(encodings)
    hypno = np.asarray(hypno)

    abs_skews = []
    kurtoses = []
    for label in np.unique(hypno):
        label_encodings = encodings[hypno == label, :]
        kurtoses.append(kurtosis(label_encodings, axis=0))
        # The absolute value keeps opposite-signed skews from cancelling out
        # in the final average.
        abs_skews.append(np.abs(skew(label_encodings, axis=0)))
    return np.mean(abs_skews), np.mean(kurtoses)

# Load the per-hyperparameter metrics table.
# NOTE(review): the path ends in .npy but is parsed as CSV -- confirm the file format.
metrics_hmm = pd.read_csv('/mnt/home/tt1131/neighbor_vae_expts/results/metrics_hp_hmm_091723.npy')

# The saved array wraps a single Python object that is indexed below by a
# hyperparameter string. allow_pickle=True unpickles on load, so only trusted
# files should be read this way.
encodings_hmm = np.load('/mnt/home/tt1131/neighbor_vae_expts/results/encodings_per_hp_hmm_091723.npy', allow_pickle=True).item()

# Download the dataset artifact through wandb and read the test labels from it.
run = wandb.init()
artifact = run.use_artifact('engellab/neighbor-vae/synthetic_hmm:v0', type='dataset')
artifact_dir = artifact.download()
test_hypno = np.load(os.path.join(artifact_dir, 'test_hypno.npy'))

# Columns whose values, joined with "_", form the key into encodings_hmm.
hp_columns = ['n_layers','layer_dims', 'kl', 'batch_size', 'lr']

# NOTE(review): .loc[i, ...] with i from range() relies on the default
# integer row labels produced by read_csv -- confirm no index column is set.
for i in range(len(metrics_hmm)):
    hp_values = list(metrics_hmm.loc[i, hp_columns])
    query_str = "_".join(map(str, hp_values))

    # One (silhouette, (skew, kurtosis)) entry per repeat stored under this key.
    per_repeat = [
        (
            metrics.silhouette_score(repeat, test_hypno),
            get_skew_and_kurtosis(repeat, test_hypno),
        )
        for repeat in encodings_hmm[query_str]
    ]
    silhouettes = [sil for sil, _ in per_repeat]
    skews = [sk for _, (sk, _) in per_repeat]
    kurtosises = [k for _, (_, k) in per_repeat]

    # Mean and standard deviation across repeats, written cell by cell so the
    # result columns are created on first assignment.
    summary = {
        'silhouette mean': np.mean(silhouettes),
        'silhouette std': np.std(silhouettes),
        'skew mean': np.mean(skews),
        'skew std': np.std(skews),
        'kurtosis mean': np.mean(kurtosises),
        'kurtosis std': np.std(kurtosises),
    }
    for column, value in summary.items():
        metrics_hmm.loc[i, column] = value

# Persist the augmented table (written to the current working directory).
metrics_hmm.to_csv('metrics_hmm')
