import os
import wandb
import joblib

import optuna
from optuna.samplers import TPESampler, GridSampler
from optuna.integration.wandb import WeightsAndBiasesCallback

import tensorflow as tf

from braivest.train.Trainer import Trainer
from braivest.utils import load_data

from braivest.model.emgVAE import emgVAE
from sklearn.model_selection import train_test_split

import argparse
## Command-line interface. All values are copied into module-level constants below.
parser = argparse.ArgumentParser(
    description='Grid-search emgVAE hyperparameters with Optuna, logging to Weights & Biases.')

parser.add_argument('--artifact_dir', type=str,
                    default='/scratch/gpfs/tt1131/projects/neighbor_vae_experiments/dataset/synthetic_spiral',
                    help='directory passed to load_data for train.npy / train_Y.npy')
parser.add_argument('--project', type=str, default='neighbor_vae_della_spiral',
                    help='W&B project suffix; the run is logged to optuna_<project>')
# Required: every output path is built from it, and without it the paths silently
# become "None/..." and the script only fails after a full training run.
parser.add_argument('--save_dir', type=str, required=True,
                    help='output directory for model weights and the pickled study')
parser.add_argument('--n_trials', type=int)
parser.add_argument('--repeat', type=int,
                    help='repeat index embedded in output filenames')
parser.add_argument('--seed', type=int, default=42, help='random seed')

args = parser.parse_args()
ARTIFACT_DIR = args.artifact_dir
PROJECT_NAME = args.project
SAVE_DIR = args.save_dir
N_TRIALS = args.n_trials
REPEAT = args.repeat
SEED = args.seed

## Baseline run configuration. It is printed at startup and passed to W&B as the run config.
## NOTE: num_layers, layer_dims, batch_size, lr and kl are not read by `objective`;
## those hyperparameters are sampled by Optuna for each trial instead.
MODEL_CONFIG = dict(
    num_layers=2,
    layer_dims=250,
    batch_size=10000,
    latent_dim=2,
    lr=1e-3,
    nw=0,
    kl=1e-5,
    time=True,       # truthy -> objective loads separate targets from train_Y.npy
    emg=False,
    save_best=False,
    epochs=500,
    metric='mse',
    val_size=0.2,    # fraction of the training data held out for validation
    seed=SEED,
)

def objective(trial):
    """Train one emgVAE configuration sampled by ``trial`` and return its validation loss.

    The architecture and optimizer hyperparameters (n_layers, layer_dims, kl,
    batch_size, lr) are suggested by the trial; everything else comes from
    MODEL_CONFIG. Final train/validation metrics are stored as trial user
    attributes, and the trained weights are written to
    ``{SAVE_DIR}/model_weights_{REPEAT}_{trial.number}.h5``.

    Args:
        trial: the Optuna trial being evaluated.

    Returns:
        The total validation loss (first value returned by ``model.evaluate``),
        which the study minimizes.
    """
    ## Load data
    train_X = load_data(ARTIFACT_DIR, 'train.npy')
    if MODEL_CONFIG['time']:
        # Targets are stored separately from the inputs.
        train_Y = load_data(ARTIFACT_DIR, 'train_Y.npy')
    else:
        # Plain autoencoding: the target is the input itself.
        train_Y = train_X
    input_dim = train_X.shape[1]

    ## Optuna search space: n_layers, layer_dims, kl, batch_size, learning rate
    n_layers = trial.suggest_int('n_layers', 1, 5)
    layer_dims = trial.suggest_int('layer_dims', 50, 500)
    kl = trial.suggest_float('kl', 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_int('batch_size', 32, 10000)
    learning_rate = trial.suggest_float('lr', 1e-5, 1e-2, log=True)

    # Use the configured seed (not a hard-coded 42) so --seed also controls the split.
    x_train, x_val, y_train, y_val = train_test_split(
        train_X, train_Y,
        test_size=MODEL_CONFIG['val_size'],
        shuffle=True,
        random_state=MODEL_CONFIG['seed'])

    train_set = (tf.convert_to_tensor(x_train), tf.convert_to_tensor(y_train))
    val_set = (tf.convert_to_tensor(x_val), tf.convert_to_tensor(y_val))

    tf.random.set_seed(MODEL_CONFIG['seed'])
    model = emgVAE(input_dim=input_dim,
                   latent_dim=MODEL_CONFIG['latent_dim'],
                   hidden_states=[layer_dims] * n_layers,  # n_layers equal-width hidden layers
                   kl=kl,
                   emg=MODEL_CONFIG['emg'])

    model.compile(loss='mse',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  metrics=MODEL_CONFIG['metric'])
    model.fit(train_set[0], train_set[1], epochs=MODEL_CONFIG['epochs'],
              batch_size=batch_size, validation_data=val_set, verbose=False)

    # Keras resets the model's metric trackers at the start of every evaluate()
    # call, and fit() itself evaluates on validation_data after each epoch. So
    # reading the trackers here would return validation-set values, not
    # training-set ones. Evaluate each split explicitly instead.
    loss, nl_loss, mse = model.evaluate(train_set[0], train_set[1], verbose=False)
    val_loss, val_nl_loss, val_mse = model.evaluate(val_set[0], val_set[1], verbose=False)

    # Plain floats keep the attributes serializable by any Optuna storage backend.
    trial.set_user_attr('neighbor_loss', float(nl_loss))
    trial.set_user_attr('mse', float(mse))
    trial.set_user_attr('loss', float(loss))

    trial.set_user_attr('val_neighbor_loss', float(val_nl_loss))
    trial.set_user_attr('val_mse', float(val_mse))
    trial.set_user_attr('val_loss', float(val_loss))

    model.save_weights(f'{SAVE_DIR}/model_weights_{REPEAT}_{trial.number}.h5')

    return val_loss

if __name__ == '__main__':
    # Fail fast, before any training, if there is nowhere to write results.
    if SAVE_DIR is None:
        parser.error('--save_dir is required')
    # joblib.dump (and presumably the .h5 weight writer) will not create missing parents.
    os.makedirs(SAVE_DIR, exist_ok=True)

    print(MODEL_CONFIG)
    wandb_kwargs = {"project": f"optuna_{PROJECT_NAME}",
                    "config": MODEL_CONFIG}
    wandbc = WeightsAndBiasesCallback(wandb_kwargs=wandb_kwargs)

    # Every grid value must lie inside the matching suggest_* range in objective().
    search_space = {
        'n_layers': [2, 3, 4],
        'layer_dims': [50, 100, 300, 400],
        'kl': [1e-4, 1e-3],
        'batch_size': [128, 256, 512, 1024],
        'lr': [1e-4, 3e-4, 1e-3],
    }
    # Seed the grid order from --seed instead of a hard-coded 42.
    study = optuna.create_study(direction="minimize",
                                sampler=GridSampler(search_space, seed=SEED))

    # Honor --n_trials. With N_TRIALS=None the study runs until the
    # GridSampler has evaluated every grid point and stops the study itself.
    study.optimize(objective, n_trials=N_TRIALS, callbacks=[wandbc])

    joblib.dump(study, f"{SAVE_DIR}/study_r_{REPEAT}.pkl")

    print(f"Number of finished trials: {len(study.trials)}")

    print("Best trial:")
    trial = study.best_trial

    print(f"  Value: {trial.value}")
    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")