import numpy as np
import pandas as pd
import os
from sklearn import metrics
import wandb
import pickle
import joblib
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42

import numpy as np
import pandas as pd
import os
from sklearn import metrics
import wandb
import joblib
from braivest.model.emgVAE import emgVAE
import tensorflow as tf
from sklearn.linear_model import LinearRegression
from scipy.stats import linregress

def model_encode(model_path, trial_number, repeat_id, input_dim, study_config, test):
    """Rebuild a trained emgVAE from saved weights and encode ``test``.

    The architecture is reconstructed from ``study_config``: ``n_layers`` hidden
    layers, each ``layer_dims`` wide, a 2-dimensional latent space, the stored
    ``kl`` setting, and ``emg=False``. Weights are read from
    ``<model_path>/model_weights_<repeat_id>_<trial_number>.h5``.

    Args:
        model_path: directory containing the saved weight files.
        trial_number: trial index used in the weight file name.
        repeat_id: repeat/study index used in the weight file name.
        input_dim: feature dimension the model is built for.
        study_config: mapping with at least 'n_layers', 'layer_dims' and 'kl'.
        test: data passed to ``model.encode``.

    Returns:
        Whatever ``emgVAE.encode(test, numpy=True)`` returns (the latent encodings).
    """
    depth = int(study_config['n_layers'])
    width = int(study_config['layer_dims'])

    vae = emgVAE(
        input_dim=input_dim,
        latent_dim=2,
        hidden_states=[width] * depth,
        kl=study_config['kl'],
        emg=False,
    )
    # Build with a variable batch dimension so the weights can be loaded.
    vae.build((None, input_dim))

    weights_file = os.path.join(model_path, f'model_weights_{repeat_id}_{trial_number}.h5')
    vae.load_weights(weights_file)
    latent = vae.encode(test, numpy=True)

    # Reset Keras global state so repeated calls (one per trial) don't pile up models.
    tf.keras.backend.clear_session()
    del vae
    return latent

def continuous_silhouette(encodings, bin_size = 100):
    """Silhouette score of ``encodings`` grouped into consecutive position bins.

    Rows are labelled by position: rows ``0 .. bin_size-1`` get label 0, the next
    ``bin_size`` rows get label 1, and so on. When ``len(encodings)`` is not a
    multiple of ``bin_size``, the trailing partial bin gets its own label.

    Args:
        encodings: array-like passed as ``X`` to ``sklearn.metrics.silhouette_score``
            (one row per sample).
        bin_size: number of consecutive rows sharing a label; must be positive.

    Returns:
        The silhouette score returned by ``metrics.silhouette_score``.

    Raises:
        ValueError: if ``bin_size`` is not positive. ``silhouette_score`` may also
            raise, e.g. if the binning yields fewer than 2 labels or more than
            ``n_samples - 1`` labels.
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size!r}")
    # Floor-dividing the row positions yields exactly one label per row.
    # Building labels from np.arange(len/bin_size) with true division would
    # round the bin count up and produce more labels than rows for lengths that
    # are not multiples of bin_size.
    labels = np.arange(len(encodings)) // bin_size
    return metrics.silhouette_score(encodings, labels)

# Directory holding the study pickles (study_r_<i>.pkl) and the per-trial
# weight files read by model_encode.
runs_path = '/mnt/home/tt1131/neighbor_vae_expts/runs_and_models/spiral/run21_spiral_TNVAE/models'

artifact_dir = '../artifacts/synthetic_spiral:v4'
# The second-to-last stacked test dataset is the one encoded below.
test = np.load(os.path.join(artifact_dir, 'test_datasets.npy'))[-2]
test_low = np.load(os.path.join(artifact_dir, 'test_low.npy'))

all_hp_labels = ['n_layers', 'layer_dims', 'kl', 'batch_size', 'lr']
all_metric_labels = ['loss', 'val_loss', 'neighbor_loss',
                     'val_neighbor_loss', 'mse', 'val_mse']

all_labels = np.array(all_hp_labels + all_metric_labels)
all_runs_params_list = []

# Collect one dict per trial and build the DataFrame once at the end. Growing a
# DataFrame row-by-row with .loc enlargement is quadratic and relied on pandas
# accepting dict views as indexer/value.
rows = []
for study_idx in range(3):
    # NOTE(review): presumably these are pickled Optuna studies (trials with
    # .number/.params/.user_attrs) — confirm.
    study = joblib.load(os.path.join(runs_path, "study_r_{}.pkl".format(study_idx)))
    for trial in study.trials:
        row = {'study_idx': study_idx, 'run_idx': trial.number}
        row.update(trial.params)
        row.update(trial.user_attrs)
        # NOTE(review): 31 is the hard-coded model input dim; presumably equals
        # test.shape[-1] — confirm.
        encodings = model_encode(runs_path, trial.number, study_idx, 31, trial.params, test)
        row['silhouette'] = continuous_silhouette(encodings)
        rows.append(row)

# Declared hyperparameter/metric columns come first (present even if no trial
# set them), followed by any other keys in order of first appearance.
metrics_df = pd.DataFrame(rows)
declared = set(all_labels)
extra_cols = [col for col in metrics_df.columns if col not in declared]
metrics_df = metrics_df.reindex(columns=list(all_labels) + extra_cols)

metrics_df.to_csv('metrics_spiral')
