import numpy as np
import pandas as pd
from scipy.stats import skew, kurtosis
import os
from sklearn import metrics
import wandb
import joblib
from braivest.model.emgVAE import emgVAE
import tensorflow as tf

def model_encode(model_path, trial_number, repeat_id, input_dim, study_config, test):
    """Rebuild a trained emgVAE from saved weights and encode ``test`` with it.

    Args:
        model_path: Directory containing the saved weight files.
        trial_number: Trial index used in the weight file name.
        repeat_id: Repeat index used in the weight file name.
        input_dim: Input feature dimension passed to the model and to ``build``.
        study_config: Mapping providing ``'n_layers'``, ``'layer_dims'`` and
            ``'kl'`` for the trial being restored.
        test: Data handed to ``model.encode``.

    Returns:
        Whatever ``model.encode(test, numpy=True)`` returns (presumably a
        NumPy array of 2-D latent codes -- confirm against emgVAE).
    """
    depth = int(study_config['n_layers'])
    width = int(study_config['layer_dims'])

    # Every hidden layer shares the same width.
    hidden_sizes = [width] * depth

    vae = emgVAE(
        input_dim=input_dim,
        latent_dim=2,
        hidden_states=hidden_sizes,
        kl=study_config['kl'],
        emg=False,
    )

    # The model is built before the weights are loaded into it.
    vae.build((None, input_dim))
    weights_file = os.path.join(
        model_path, f'model_weights_{repeat_id}_{trial_number}.h5'
    )
    vae.load_weights(weights_file)

    latent = vae.encode(test, numpy=True)

    # Reset Keras global state and drop the local reference so repeated calls
    # (one per trial) do not keep accumulating models.
    tf.keras.backend.clear_session()
    del vae
    return latent

def get_skew_and_kurtosis(encodings, hypno):
    """Summarise per-label skewness and kurtosis of a set of encodings.

    For every distinct value in ``hypno`` the rows of ``encodings`` carrying
    that label are selected, and the absolute skewness and the kurtosis are
    computed column-wise with scipy's default settings.

    Args:
        encodings: 2-D array-like, one row per sample. Coerced with
            ``np.asarray`` so nested lists/tuples are accepted as well as
            arrays.
        hypno: 1-D array-like of labels, one per row of ``encodings``.

    Returns:
        A tuple ``(mean_abs_skew, mean_kurtosis)``, each averaged over all
        labels and all columns.
    """
    # Coerce both inputs: boolean-mask indexing below needs real ndarrays.
    encodings = np.asarray(encodings)
    hypno = np.asarray(hypno)

    per_label = [encodings[hypno == label, :] for label in np.unique(hypno)]
    sks = [np.abs(skew(rows, axis=0)) for rows in per_label]
    ks = [kurtosis(rows, axis=0) for rows in per_label]
    return np.mean(sks), np.mean(ks)

# --- Inputs: saved study, held-out data and its labels ----------------------
runs_path = '/mnt/home/tt1131/neighbor_vae_expts/runs_and_models/vae_hmm/runs15-18/run16_VAE_real/models/'
study_file = os.path.join(runs_path, "study_r_0.pkl")
study = joblib.load(study_file)

artifact_dir = 'artifacts/probe12_subject0_test0:v0'
test_file = os.path.join(artifact_dir, 'test.npy')
hypno_file = os.path.join(artifact_dir, 'hypno.npy')
test = np.load(test_file)
# Only the first entry along axis 0 of the stored hypnogram array is used.
test_hypno = np.load(hypno_file)[0]

# --- Result table layout -----------------------------------------------------
all_hp_labels = ['n_layers', 'layer_dims', 'kl', 'batch_size', 'lr']
all_metric_labels = [
    'loss', 'val_loss', 'neighbor_loss',
    'val_neighbor_loss', 'mse', 'val_mse',
    'silhouette', 'skew', 'kurtosis',
]

all_labels = np.array(all_hp_labels + all_metric_labels)
all_runs_params_list = []
metrics_df = pd.DataFrame(columns=all_labels)

# Encodings are truncated to this many rows before being scored.
n_labelled = len(test_hypno)

# --- One row per trial: hyperparameters, stored attrs, latent-space metrics --
for trial in study.trials:
    row = trial.number

    # 'run_idx' is not among the initial columns; this assignment adds it.
    metrics_df.loc[row, 'run_idx'] = row
    metrics_df.loc[row, trial.params.keys()] = trial.params.values()
    metrics_df.loc[row, trial.user_attrs.keys()] = trial.user_attrs.values()

    # Repeat id 0 and input dimension 31 are fixed for this study.
    encodings = model_encode(runs_path, row, 0, 31, trial.params, test)
    labelled_encodings = encodings[:n_labelled]

    # NOTE(review): sample_size is given without a random_state, so the
    # sampled subset (and the score) may differ between runs -- confirm.
    metrics_df.loc[row, 'silhouette'] = metrics.silhouette_score(
        labelled_encodings, test_hypno, n_jobs=-1, sample_size=25000
    )

    sk, k = get_skew_and_kurtosis(labelled_encodings, test_hypno)
    metrics_df.loc[row, 'skew'] = sk
    metrics_df.loc[row, 'kurtosis'] = k

metrics_df.to_csv('metrics_real_vanilla')
