{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, warnings\n",
    "# TensorFlow C++ log level: '3' is the quietest setting; '0'-'2' let progressively more through.\n",
    "# Set here, before tensorflow is imported in the next cell.\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # or any {'0', '1', '2'}\n",
    "# NOTE(review): this hides ALL Python warnings for the whole session, not only TensorFlow's.\n",
    "warnings.filterwarnings('ignore') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import shuffle\n",
    "import sys, os\n",
    "from datetime import datetime, timedelta\n",
    "import numpy as np , pandas as pd\n",
    "import time\n",
    "import joblib\n",
    "import random\n",
    "import multiprocessing\n",
    "\n",
    "\n",
    "from sklearn.model_selection import KFold\n",
    "from skopt import gp_minimize\n",
    "from skopt.space import Real, Categorical, Integer\n",
    "from skopt.plots import plot_convergence\n",
    "from skopt.plots import plot_objective, plot_evaluations\n",
    "from skopt.utils import use_named_args\n",
    "from skopt import Optimizer # for the optimization\n",
    "from joblib import Parallel, delayed # for the parallelization\n",
    "import scipy\n",
    "\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau\n",
    "from vae_dense_model import VariationalAutoencoderDense as VAE_Dense\n",
    "from vae_conv_model import VariationalAutoencoderConv as VAE_Conv\n",
    "from vae_conv_I_model import VariationalAutoencoderConvInterpretable as VAE_ConvI\n",
    "from config import config as cfg\n",
    "import utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative paths, resolved against the notebook's working directory.\n",
    "input_dir = \"../../data/processed_orig_data/\"   # source of {data_name}_subsampled_train_perc_{p}.npz\n",
    "output_dir = \"../../data/generated_data/\"       # generated samples are saved under here\n",
    "model_dir = './model/'                           # NOTE(review): appears unused by the HPO loop - confirm\n",
    "log_dir = './log/'                               # NOTE(review): appears unused by the HPO loop - confirm\n",
    "hpo_dir = './hpo_results/'                       # HPO results CSVs (one per dataset/perc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the shared evaluation-metric modules on the import path (path is relative to this notebook).\n",
    "sys.path.insert(0, '../../evaluations/disc_and_preds/metrics/')\n",
    "\n",
    "from discriminative_metrics3 import discriminative_score_metrics\n",
    "from predictive_metrics3 import predictive_score_metrics\n",
    "from visualization_metrics import visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Utility Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seeds(seed_value):\n",
    "    \"\"\"Seed every random source this notebook uses, for reproducibility.\n",
    "\n",
    "    Sets PYTHONHASHSEED in the environment, then seeds Python's `random`,\n",
    "    NumPy's global RNG and TensorFlow, in that order.\n",
    "    NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so it only\n",
    "    affects processes launched after this call, not the running kernel.\n",
    "    \"\"\"\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed_value)\n",
    "    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):\n",
    "        seed_fn(seed_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_train_valid_split(data, valid_perc):\n",
    "    \"\"\"Randomly split `data` along its first axis into train and validation parts.\n",
    "\n",
    "    Args:\n",
    "        data: array indexable along axis 0 (samples first).\n",
    "        valid_perc: fraction in [0, 1] of samples assigned to the validation split.\n",
    "\n",
    "    Returns:\n",
    "        (train_data, valid_data): the first int(N * (1 - valid_perc)) shuffled samples\n",
    "        and the rest. The caller's array is NOT modified.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if valid_perc is outside [0, 1].\n",
    "    \"\"\"\n",
    "    if not 0.0 <= valid_perc <= 1.0:\n",
    "        raise ValueError(f'valid_perc must be in [0, 1], got {valid_perc}')\n",
    "\n",
    "    data = np.asarray(data)\n",
    "    N = data.shape[0]\n",
    "    N_train = int(N * (1 - valid_perc))\n",
    "\n",
    "    # shuffle via a permutation index so the caller's array is left untouched;\n",
    "    # it still draws from the global NumPy RNG, so set_seeds() controls it\n",
    "    shuffled = data[np.random.permutation(N)]\n",
    "\n",
    "    # train, valid split\n",
    "    train_data = shuffled[:N_train]\n",
    "    valid_data = shuffled[N_train:]\n",
    "    return train_data, valid_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_scaler(scaling_method):\n",
    "    \"\"\"Return a new, unfitted scaler for `scaling_method`.\n",
    "\n",
    "    Only 'minmax' (utils.MinMaxScaler) is supported. Any other value, including\n",
    "    'standard' and 'yeojohnson', raises NotImplementedError.\n",
    "    \"\"\"\n",
    "    if scaling_method != 'minmax':\n",
    "        raise NotImplementedError(f'Scaling method {scaling_method} not implemented')\n",
    "    return utils.MinMaxScaler()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scale_train_valid_data(train_data, valid_data, scaling_method):\n",
    "    \"\"\"Fit a scaler on the training data only, then apply it to both splits.\n",
    "\n",
    "    Fitting on train alone keeps validation statistics out of the scaler.\n",
    "\n",
    "    Args:\n",
    "        train_data, valid_data: data accepted by the scaler's fit_transform/transform.\n",
    "        scaling_method: name passed to get_scaler (currently only 'minmax').\n",
    "\n",
    "    Returns:\n",
    "        (scaled_train_data, scaled_valid_data, scaler). The fitted scaler is returned\n",
    "        so data can later be mapped back with scaler.inverse_transform.\n",
    "    \"\"\"\n",
    "    scaler = get_scaler(scaling_method)\n",
    "    scaled_train_data = scaler.fit_transform(train_data)\n",
    "    scaled_valid_data = scaler.transform(valid_data)\n",
    "    return scaled_train_data, scaled_valid_data, scaler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VAE Training, Sampling and Evaluation Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(train_data, latent_dim, n_layer1_in_25s, n_layer2_in_25s, n_layer3_in_50s, \n",
    "                reconstruction_wt, epochs = 100, batch_size = 32):\n",
    "    \"\"\"Build, compile and fit an interpretable convolutional VAE (VAE_ConvI).\n",
    "\n",
    "    Args:\n",
    "        train_data: array indexed as (num_samples, seq_len, feat_dim); the VAE's\n",
    "            `seq_len` and `feat_dim` are taken from its shape.\n",
    "        latent_dim: latent-space size (cast to int).\n",
    "        n_layer1_in_25s, n_layer2_in_25s: first/second hidden-layer widths in units of 25.\n",
    "        n_layer3_in_50s: third hidden-layer width in units of 50.\n",
    "        reconstruction_wt: forwarded unchanged as the VAE's `reconstruction_wt`.\n",
    "        epochs: maximum training epochs; early stopping may end training sooner.\n",
    "        batch_size: mini-batch size for `fit` (default 32, the former hard-coded value).\n",
    "\n",
    "    Returns:\n",
    "        (vae, history): the trained model and the object returned by `vae.fit`.\n",
    "    \"\"\"\n",
    "    _, T, D = train_data.shape\n",
    "\n",
    "    # ----------------------------------------------------------------------------------------------\n",
    "    # Instantiate the VAE\n",
    "    vae = VAE_ConvI( seq_len=T,\n",
    "                    feat_dim = D,\n",
    "                    latent_dim = int(latent_dim),\n",
    "                    hidden_layer_sizes=[\n",
    "                        int(n_layer1_in_25s*25),\n",
    "                        int(n_layer2_in_25s*25),\n",
    "                        int(n_layer3_in_50s*50)],\n",
    "                    reconstruction_wt = reconstruction_wt,\n",
    "                    # optional interpretable-component arguments, currently not passed:\n",
    "                    # trend_poly=1,\n",
    "                    # num_gen_seas=1,\n",
    "                    # custom_seas = [ (7, 1)] ,     # list of tuples of (num_of_seasons, len_per_season)\n",
    "                    use_residual_conn = True\n",
    "        )\n",
    "\n",
    "    vae.compile(optimizer=Adam())\n",
    "\n",
    "    # ----------------------------------------------------------------------------------------------\n",
    "    # Train the VAE. Both callbacks watch the training loss (fit gets no validation data).\n",
    "    monitor_loss = 'loss'\n",
    "    early_stop_callback = EarlyStopping(monitor=monitor_loss, min_delta = 1e-1, patience=50)\n",
    "    reduceLR = ReduceLROnPlateau(monitor=monitor_loss, factor=0.1, patience=10)\n",
    "\n",
    "    history = vae.fit(\n",
    "        train_data,\n",
    "        batch_size = batch_size,\n",
    "        epochs=epochs,\n",
    "        shuffle = True,\n",
    "        callbacks=[early_stop_callback, reduceLR],\n",
    "        verbose = 0\n",
    "    )\n",
    "\n",
    "    # ----------------------------------------------------------------------------------------------\n",
    "    return vae, history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_and_save_samples(vae, num_samples, scaler, model_name=None, data_name=None,\n",
    "                              perc=None, out_dir=None):\n",
    "    \"\"\"Sample from the VAE prior, undo the scaling, and save the result as a compressed .npz.\n",
    "\n",
    "    The array is stored under key 'data' in\n",
    "    `{out_dir}/{model_name}/{model_name}_gen_samples_{data_name}_perc_{perc}.npz`.\n",
    "    The `{model_name}` sub-directory is created if it does not exist.\n",
    "\n",
    "    Args:\n",
    "        vae: trained model exposing `get_prior_samples(num_samples=...)`.\n",
    "        num_samples: number of samples to draw.\n",
    "        scaler: fitted scaler exposing `inverse_transform`.\n",
    "        model_name, data_name, perc, out_dir: output naming/location. Each defaults to\n",
    "            the notebook global it used to be read from implicitly (`model`, `data_name`,\n",
    "            `p`, `output_dir`), so existing three-argument calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        Path of the written file.\n",
    "    \"\"\"\n",
    "    nb_globals = globals()\n",
    "    if model_name is None:\n",
    "        model_name = nb_globals['model']\n",
    "    if data_name is None:\n",
    "        data_name = nb_globals['data_name']\n",
    "    if perc is None:\n",
    "        perc = nb_globals['p']\n",
    "    if out_dir is None:\n",
    "        out_dir = nb_globals['output_dir']\n",
    "\n",
    "    samples = vae.get_prior_samples(num_samples=num_samples)\n",
    "\n",
    "    # map samples back to the original (unscaled) data range\n",
    "    samples = scaler.inverse_transform(samples)\n",
    "\n",
    "    samples_fpath = os.path.join(out_dir, model_name,\n",
    "                                 f'{model_name}_gen_samples_{data_name}_perc_{perc}.npz')\n",
    "    # np.savez_compressed does not create missing directories\n",
    "    os.makedirs(os.path.dirname(samples_fpath), exist_ok=True)\n",
    "    np.savez_compressed(samples_fpath, data=samples)\n",
    "    return samples_fpath\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_model(model, valid_data):\n",
    "    \"\"\"Evaluate `model` on `valid_data` without progress output; metrics come back as a dict.\"\"\"\n",
    "    eval_kwargs = dict(verbose=0, return_dict=True)\n",
    "    return model.evaluate(valid_data, **eval_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def confidence_interval(data, confidence=0.95):\n",
    "    \"\"\"Mean and Student-t confidence half-width of a sample.\n",
    "\n",
    "    Args:\n",
    "        data: sequence of numbers.\n",
    "        confidence: two-sided confidence level in (0, 1).\n",
    "\n",
    "    Returns:\n",
    "        (m, h) such that the interval is m +/- h. With fewer than two values the\n",
    "        half-width is undefined and returned as NaN (m is also NaN for empty input).\n",
    "    \"\"\"\n",
    "    # `import scipy` alone does not load the `stats` submodule in older SciPy\n",
    "    # releases, so import it explicitly here.\n",
    "    from scipy import stats\n",
    "\n",
    "    a = 1.0 * np.array(data)\n",
    "    n = len(a)\n",
    "    if n < 2:\n",
    "        # sem / t.ppf are undefined for n < 2; return NaN directly instead of via warnings\n",
    "        m = np.mean(a) if n else float('nan')\n",
    "        return m, float('nan')\n",
    "    m, se = np.mean(a), stats.sem(a)\n",
    "    h = se * stats.t.ppf((1 + confidence) / 2., n - 1)\n",
    "    return m, h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_samples(scaled_ori_data, scaled_gen_data, predictor = 'conv', pred_epochs = 500, disc_epochs = 500,\n",
    "                     num_iterations = None):\n",
    "    \"\"\"Score generated data against the original data over several repetitions.\n",
    "\n",
    "    Runs the predictive score `num_iterations` times and prints mean +/- 95% CI.\n",
    "    The discriminative metric is currently disabled and recorded as -1.\n",
    "\n",
    "    Args:\n",
    "        scaled_ori_data, scaled_gen_data: original and generated data on the same scale.\n",
    "        predictor: predictor passed to predictive_score_metrics (conv, rnn, nbeats).\n",
    "        pred_epochs: epochs for the predictive-score model.\n",
    "        disc_epochs: epochs for the discriminative model (unused while it is disabled).\n",
    "        num_iterations: repetitions; defaults to the notebook global `metric_iteration`.\n",
    "\n",
    "    Returns:\n",
    "        (pred_mean, disc_mean), each rounded to 4 decimals.\n",
    "    \"\"\"\n",
    "    if num_iterations is None:\n",
    "        num_iterations = metric_iteration\n",
    "\n",
    "    predictive_score = []\n",
    "    discriminative_score = []\n",
    "    for _ in range(num_iterations):\n",
    "        pred_score = predictive_score_metrics(scaled_ori_data, scaled_gen_data,\n",
    "                                              predictor = predictor,\n",
    "                                              epochs = pred_epochs, )\n",
    "        predictive_score.append(pred_score)\n",
    "\n",
    "        # discriminative metric disabled; to re-enable:\n",
    "        # disc_score = discriminative_score_metrics(scaled_ori_data, scaled_gen_data,  epochs = disc_epochs)\n",
    "        disc_score = -1\n",
    "        discriminative_score.append(disc_score)\n",
    "\n",
    "    pred_mean = np.round(np.mean(predictive_score), 4)\n",
    "    pred_CI = np.round(confidence_interval(predictive_score)[1], 4)\n",
    "\n",
    "    disc_mean = np.round(np.mean(discriminative_score), 4)\n",
    "    disc_CI = np.round(confidence_interval(discriminative_score)[1], 4)\n",
    "\n",
    "    print('Predictive score: ', pred_mean, \"+/-\", pred_CI)\n",
    "    print('Discriminative score: ', disc_mean, \"+/-\", disc_CI)\n",
    "\n",
    "    return pred_mean, disc_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_samples2(scaled_ori_data, scaled_gen_data, predictor = 'conv', pred_epochs = 500, disc_epochs = 500):\n",
    "    \"\"\"Single-shot scoring of generated data (one predictive-score run).\n",
    "\n",
    "    Args:\n",
    "        scaled_ori_data, scaled_gen_data: original and generated data on the same scale.\n",
    "        predictor: predictor passed to predictive_score_metrics (conv, rnn, nbeats).\n",
    "        pred_epochs: epochs for the predictive-score model.\n",
    "        disc_epochs: epochs for the discriminative model (unused while it is disabled).\n",
    "\n",
    "    Returns:\n",
    "        (pred_score, disc_score); disc_score is -1 because the discriminative metric is disabled.\n",
    "    \"\"\"\n",
    "    pred_score = predictive_score_metrics(scaled_ori_data, scaled_gen_data,\n",
    "                                          predictor = predictor,\n",
    "                                          epochs = pred_epochs, )\n",
    "\n",
    "    # discriminative metric disabled; to re-enable:\n",
    "    # disc_score = discriminative_score_metrics(scaled_ori_data, scaled_gen_data,  epochs = disc_epochs)\n",
    "    disc_score = -1\n",
    "    return pred_score, disc_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HPO Using Scikit Optimize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter Space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[8, 2, 2, 4, 'minmax', 1.5]"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyperparameter search space for skopt. Dimensions given as a single-value\n",
    "# Categorical are effectively fixed (not searched).\n",
    "param_grid = [\n",
    "    Integer(2, 10, name=\"latent_dim\"),    # wider alternative: 1-10\n",
    "\n",
    "    # Layer widths are fixed for now. To search them instead, swap in e.g.:\n",
    "    #   Integer(2, 3, name=\"n_layer1_in_25s\"),  (wider: 1-5)\n",
    "    #   Integer(3, 4, name=\"n_layer2_in_25s\"),  (wider: 1-8)\n",
    "    #   Integer(3, 4, name=\"n_layer3_in_50s\"),  (wider: 1-8)\n",
    "    Categorical([2], name=\"n_layer1_in_25s\"),\n",
    "    Categorical([2], name=\"n_layer2_in_25s\"),\n",
    "    Categorical([4], name=\"n_layer3_in_50s\"),\n",
    "\n",
    "    Categorical(['minmax'], name=\"scaling_method\"),\n",
    "    Real(0.5, 4.5, prior='uniform', name='reconstruction_wt'),\n",
    "]\n",
    "\n",
    "# Column names for the HPO results table, derived from param_grid so the two\n",
    "# cannot drift out of sync.\n",
    "dim_names = [dim.name for dim in param_grid]\n",
    "\n",
    "default_parameters = [\n",
    "    8,              # latent_dim\n",
    "    2,              # n_layer1_in_25s\n",
    "    2,              # n_layer2_in_25s\n",
    "    4,              # n_layer3_in_50s\n",
    "    'minmax',       # scaling_method\n",
    "    1.5,            # reconstruction_wt\n",
    "]\n",
    "assert len(default_parameters) == len(param_grid), 'need one default per search dimension'\n",
    "\n",
    "default_parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Objective for HPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "@use_named_args(param_grid)\n",
    "def objective(\n",
    "            latent_dim,\n",
    "            n_layer1_in_25s,\n",
    "            n_layer2_in_25s,\n",
    "            n_layer3_in_50s,\n",
    "            scaling_method,\n",
    "            reconstruction_wt,\n",
    "        ):\n",
    "    \"\"\"HPO objective: train a VAE with one hyperparameter setting and score its samples.\n",
    "\n",
    "    `use_named_args(param_grid)` lets skopt call this with a single point (a list\n",
    "    ordered like `param_grid`). Also reads the notebook globals `ori_data`,\n",
    "    `metric_iteration`, `pred_epochs` and `disc_epochs`.\n",
    "\n",
    "    Returns:\n",
    "        Mean predictive score over `metric_iteration` scorings, rounded to 4 decimals.\n",
    "        The optimizer minimizes this value.\n",
    "    \"\"\"\n",
    "    start = time.time()\n",
    "\n",
    "    # fit the scaler on the original data and scale it\n",
    "    scaler = get_scaler(scaling_method)\n",
    "    scaled_ori_data = scaler.fit_transform(ori_data)\n",
    "\n",
    "    print(\"Training model\")\n",
    "    # named `vae` (not `model`) to avoid confusion with the global model-name string\n",
    "    vae, history = train_model(scaled_ori_data,\n",
    "            latent_dim, int(n_layer1_in_25s), int(n_layer2_in_25s), int(n_layer3_in_50s), reconstruction_wt,\n",
    "            epochs = 2000)\n",
    "\n",
    "    # generate as many samples as there are original sequences; kept in scaled space for scoring\n",
    "    scaled_gen_data = vae.get_prior_samples(num_samples= scaled_ori_data.shape[0])\n",
    "\n",
    "    # score the same samples several times to average out run-to-run variance of the metric\n",
    "    pred_scores = []; disc_scores = []\n",
    "    for tt in range(metric_iteration):\n",
    "        print(\"Scoring model \", tt)\n",
    "        pred_score, disc_score = evaluate_samples2(\n",
    "            scaled_ori_data, scaled_gen_data, predictor = 'conv', pred_epochs = pred_epochs, disc_epochs = disc_epochs)\n",
    "\n",
    "        print(\"Model Score\", tt, pred_score, disc_score)\n",
    "        pred_scores.append(pred_score); disc_scores.append(disc_score)\n",
    "\n",
    "    # Release the model and reset Keras' global state. Parallel worker processes may be\n",
    "    # reused for many trials, and without clear_session() that state keeps growing.\n",
    "    del vae, history\n",
    "    tf.keras.backend.clear_session()\n",
    "\n",
    "    trial_loss_mean = np.round(np.mean(pred_scores), 4)\n",
    "    trial_loss_CI = np.round(confidence_interval(pred_scores)[1], 4)\n",
    "    trial_loss_max = np.round(np.max(pred_scores), 4)\n",
    "\n",
    "    # Print the hyper-parameters and loss\n",
    "    print('-------------------------------------------')\n",
    "    print(f'latent_dim: {latent_dim}')\n",
    "    print(f'n_layer1_in_25s: {n_layer1_in_25s}')\n",
    "    print(f'n_layer2_in_25s: {n_layer2_in_25s}')\n",
    "    print(f'n_layer3_in_50s: {n_layer3_in_50s}')\n",
    "    print(f'reconstruction_wt: {reconstruction_wt}')\n",
    "    print(f'scaling_method: {scaling_method}')\n",
    "    print()\n",
    "    print(\"all losses:\", pred_scores)\n",
    "    print(f\"trial vae loss: {trial_loss_mean} +/- {trial_loss_CI} (max {trial_loss_max})\")\n",
    "    print(f\"Trial run time: {np.round((time.time() - start)/60.0, 2)} minutes\")\n",
    "    print('-------------------------------------------')\n",
    "\n",
    "    return trial_loss_mean\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main TimeVAE HPO Loop (per dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parallel trials (multi-process, via joblib)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_cpus_to_use:  6\n",
      "--------------------------------------------------------------------------------\n",
      "dataset: stocks2, perc: 20\n",
      "num_loops: 9\n",
      "--------------------------------------------------------------------------------\n"
     ]
    },
    {
     "ename": "UnboundLocalError",
     "evalue": "local variable 'trial_num' referenced before assignment",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31m_RemoteTraceback\u001b[0m                          Traceback (most recent call last)",
      "\u001b[0;31m_RemoteTraceback\u001b[0m: \n\"\"\"\nTraceback (most recent call last):\n  File \"/opt/anaconda3/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py\", line 418, in _process_worker\n    r = call_item()\n  File \"/opt/anaconda3/lib/python3.7/site-packages/joblib/externals/loky/process_executor.py\", line 272, in __call__\n    return self.fn(*self.args, **self.kwargs)\n  File \"/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py\", line 608, in __call__\n    return self.func(*args, **kwargs)\n  File \"/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py\", line 256, in __call__\n    for func, args, kwargs in self.items]\n  File \"/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py\", line 256, in <listcomp>\n    for func, args, kwargs in self.items]\n  File \"/opt/anaconda3/lib/python3.7/site-packages/skopt/utils.py\", line 803, in wrapper\n    objective_value = func(**arg_dict)\n  File \"<ipython-input-101-7e80664fb155>\", line 14, in objective\nUnboundLocalError: local variable 'trial_num' referenced before assignment\n\"\"\"",
      "\nThe above exception was the direct cause of the following exception:\n",
      "\u001b[0;31mUnboundLocalError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-102-96cefcac7c22>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     74\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 76\u001b[0;31m             \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParallel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_cpus_to_use\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdelayed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobjective\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# evaluate points in parallel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     77\u001b[0m             \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     78\u001b[0m             \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'\\nloop:'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'points:'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1015\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1016\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretrieval_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1017\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1018\u001b[0m             \u001b[0;31m# Make sure that we get a last message telling us we are done\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1019\u001b[0m             \u001b[0melapsed_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_start_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/anaconda3/lib/python3.7/site-packages/joblib/parallel.py\u001b[0m in \u001b[0;36mretrieve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    907\u001b[0m             \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    908\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'supports_timeout'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 909\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjob\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    910\u001b[0m                 \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    911\u001b[0m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_output\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjob\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/anaconda3/lib/python3.7/site-packages/joblib/_parallel_backends.py\u001b[0m in \u001b[0;36mwrap_future_result\u001b[0;34m(future, timeout)\u001b[0m\n\u001b[1;32m    560\u001b[0m         AsyncResults.get from multiprocessing.\"\"\"\n\u001b[1;32m    561\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 562\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mfuture\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    563\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mLokyTimeoutError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    564\u001b[0m             \u001b[0;32mraise\u001b[0m \u001b[0mTimeoutError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/anaconda3/lib/python3.7/concurrent/futures/_base.py\u001b[0m in \u001b[0;36mresult\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m    433\u001b[0m                 \u001b[0;32mraise\u001b[0m \u001b[0mCancelledError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    434\u001b[0m             \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mFINISHED\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 435\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__get_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    436\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    437\u001b[0m                 \u001b[0;32mraise\u001b[0m \u001b[0mTimeoutError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/anaconda3/lib/python3.7/concurrent/futures/_base.py\u001b[0m in \u001b[0;36m__get_result\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    382\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__get_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    383\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 384\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exception\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    385\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    386\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_result\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mUnboundLocalError\u001b[0m: local variable 'trial_num' referenced before assignment"
     ]
    }
   ],
   "source": [
    "# Use at most `max_workers` worker processes, but never more than the machine can\n",
    "# spare (2 cores are left for the kernel / OS).\n",
    "max_workers = 6\n",
    "num_cpus_to_use = min(max_workers, max(multiprocessing.cpu_count() - 2, 1))\n",
    "print(\"num_cpus_to_use: \", num_cpus_to_use)\n",
    "\n",
    "# number of times each trial's generated samples are re-scored (see objective)\n",
    "metric_iteration = 4\n",
    "\n",
    "pred_epochs = 500; disc_epochs = 500\n",
    "\n",
    "# num of trials for Bayesian search: initial and total (including initial)\n",
    "n_initial_points = max(5, num_cpus_to_use)\n",
    "n_calls = 50\n",
    "\n",
    "# size of the candidate pool requested from the optimizer each loop; already-seen\n",
    "# points are filtered out and at most one batch of new points is kept\n",
    "n_ask_points = 50\n",
    "\n",
    "# our model name\n",
    "model = 'vae_conv_I'\n",
    "\n",
    "dataset_names = ['stocks', 'stocks2', 'air', 'sine', 'energy']\n",
    "# percs = [2, 5, 10, 20, 100]\n",
    "\n",
    "# to custom run specific data\n",
    "dataset_names = ['stocks2']\n",
    "percs = [ 20 ]\n",
    "\n",
    "os.makedirs(hpo_dir, exist_ok=True)\n",
    "\n",
    "# set random gen seed for reproducibility\n",
    "# NOTE: this seeds the kernel process only; joblib workers are separate processes,\n",
    "# so individual trial results are not fully reproducible.\n",
    "set_seeds(42)\n",
    "\n",
    "main_start_time = time.time()\n",
    "\n",
    "for p in percs:\n",
    "    for data_name in dataset_names:\n",
    "        print(\"-\"*80)\n",
    "        print(f\"dataset: {data_name}, perc: {p}\")\n",
    "        # --------------------------------------------------------------------\n",
    "        ### read data (the context manager closes the .npz file handle)\n",
    "        fname = os.path.join(input_dir, f'{data_name}_subsampled_train_perc_{p}.npz')\n",
    "        with np.load(fname) as loaded:\n",
    "            ori_data = loaded['data']\n",
    "        # --------------------------------------------------------------------\n",
    "\n",
    "        optimizer = Optimizer(\n",
    "            dimensions = param_grid, # the hyperparameter space\n",
    "            base_estimator = \"GP\", # the surrogate\n",
    "            n_initial_points=n_initial_points, # the number of points to evaluate f(x) to start of\n",
    "            acq_func='EI', # the acquisition function\n",
    "            random_state=0,\n",
    "            n_jobs=num_cpus_to_use,\n",
    "        )\n",
    "        # --------------------------------------------------------------------\n",
    "\n",
    "        num_loops = int(np.ceil(n_calls / num_cpus_to_use))\n",
    "        print(f'num_loops: {num_loops}')\n",
    "\n",
    "        seen_points = set()   # points already sent for evaluation\n",
    "\n",
    "        for i in range(num_loops):\n",
    "            print('-'*80)\n",
    "            # trim the last batch so exactly n_calls points are evaluated\n",
    "            batch_size = min(num_cpus_to_use, n_calls - len(optimizer.Xi))\n",
    "            if batch_size <= 0:\n",
    "                break\n",
    "\n",
    "            x = []\n",
    "            for pt in optimizer.ask(n_points=n_ask_points):\n",
    "                pt_key = tuple(pt)   # unambiguous key, unlike concatenated str() values\n",
    "                if pt_key in seen_points: continue\n",
    "                seen_points.add(pt_key)\n",
    "                x.append(pt)\n",
    "                if len(x) == batch_size: break\n",
    "\n",
    "            if not x:\n",
    "                print('optimizer proposed no new points; stopping search for this dataset')\n",
    "                break\n",
    "\n",
    "            y = Parallel(n_jobs=num_cpus_to_use)(delayed(objective)(v) for v in x)  # evaluate points in parallel\n",
    "            optimizer.tell(x, y)\n",
    "            print('\\nloop:', i, 'points:', x, y)\n",
    "\n",
    "            cumu_time = np.round((time.time() - main_start_time)/60.0, 2)\n",
    "            print(\"iter:\", i, 'vae loss: ', optimizer.yi[-1], 'best so far:', min(optimizer.yi), 'cumu_time_mins:', cumu_time)\n",
    "\n",
    "            # technically belongs after the loop, but saving every iteration keeps partial results if a run fails\n",
    "            hpo_results = pd.concat([\n",
    "                pd.DataFrame(optimizer.Xi),\n",
    "                pd.Series(optimizer.yi),\n",
    "            ], axis=1)\n",
    "\n",
    "            hpo_results.columns = dim_names + ['loss']\n",
    "            hpo_results.insert(0, 'dataset', data_name)\n",
    "            hpo_results.insert(1, 'perc', p)\n",
    "\n",
    "            file_name = f'hpo_results_{model}_{data_name}_perc_{p}_test.csv'\n",
    "            hpo_results.to_csv(os.path.join( hpo_dir, file_name), index=False)\n",
    "\n",
    "end = time.time()\n",
    "elapsed_time = np.round((end - main_start_time)/60.0, 2)\n",
    "print(f\"All done in {elapsed_time} minutes!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inspect HPO Results "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best score=0.1110\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>dataset</th>\n",
       "      <th>perc</th>\n",
       "      <th>latent_dim</th>\n",
       "      <th>n_layer1_in_25s</th>\n",
       "      <th>n_layer2_in_25s</th>\n",
       "      <th>n_layer3_in_50s</th>\n",
       "      <th>scaling_method</th>\n",
       "      <th>reconstruction_wt</th>\n",
       "      <th>loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>44</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>minmax</td>\n",
       "      <td>0.803957</td>\n",
       "      <td>0.1151</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>minmax</td>\n",
       "      <td>0.820629</td>\n",
       "      <td>0.1147</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>46</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>7</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>minmax</td>\n",
       "      <td>0.874990</td>\n",
       "      <td>0.1145</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>6</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>4</td>\n",
       "      <td>minmax</td>\n",
       "      <td>2.957224</td>\n",
       "      <td>0.1144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>48</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>4</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>minmax</td>\n",
       "      <td>2.016872</td>\n",
       "      <td>0.1143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>6</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>minmax</td>\n",
       "      <td>1.169333</td>\n",
       "      <td>0.1141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>minmax</td>\n",
       "      <td>0.800000</td>\n",
       "      <td>0.1136</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>51</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>minmax</td>\n",
       "      <td>1.663378</td>\n",
       "      <td>0.1132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>52</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>minmax</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>0.1120</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>53</th>\n",
       "      <td>stocks</td>\n",
       "      <td>5</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>2</td>\n",
       "      <td>3</td>\n",
       "      <td>minmax</td>\n",
       "      <td>0.873437</td>\n",
       "      <td>0.1110</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   dataset  perc  latent_dim  n_layer1_in_25s  n_layer2_in_25s  \\\n",
       "44  stocks     5           2                1                2   \n",
       "45  stocks     5           2                1                2   \n",
       "46  stocks     5           7                3                2   \n",
       "47  stocks     5           6                2                2   \n",
       "48  stocks     5           4                2                1   \n",
       "49  stocks     5           6                2                2   \n",
       "50  stocks     5           2                1                1   \n",
       "51  stocks     5           2                1                1   \n",
       "52  stocks     5           2                3                2   \n",
       "53  stocks     5           2                3                2   \n",
       "\n",
       "    n_layer3_in_50s scaling_method  reconstruction_wt    loss  \n",
       "44                3         minmax           0.803957  0.1151  \n",
       "45                3         minmax           0.820629  0.1147  \n",
       "46                5         minmax           0.874990  0.1145  \n",
       "47                4         minmax           2.957224  0.1144  \n",
       "48                4         minmax           2.016872  0.1143  \n",
       "49                3         minmax           1.169333  0.1141  \n",
       "50                5         minmax           0.800000  0.1136  \n",
       "51                5         minmax           1.663378  0.1132  \n",
       "52                3         minmax           3.000000  0.1120  \n",
       "53                3         minmax           0.873437  0.1110  "
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data = 'stocks'\n",
    "test_perc = 5\n",
    "\n",
    "# `model` must already be defined by an earlier cell; `hpo_dir` comes from the paths cell.\n",
    "file_name = f'hpo_results_{model}_{test_data}_perc_{test_perc}_test.csv'\n",
    "hpt_results = pd.read_csv(os.path.join(hpo_dir, file_name))\n",
    "\n",
    "# function value at the minimum.\n",
    "print(\"Best score=%.4f\" % hpt_results['loss'].min())\n",
    "\n",
    "# Keep `hpt_results` in file order (presumably the order in which trials were evaluated)\n",
    "# so the convergence plot further down stays meaningful; sort a copy for display only.\n",
    "# Sorted descending, so the best (lowest-loss) trials are the last rows shown.\n",
    "hpt_results_by_loss = hpt_results.sort_values(by='loss', ascending=False).reset_index(drop=True)\n",
    "hpt_results_by_loss.tail(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Convergence "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss per trial in the current row order of `hpt_results`, plus the running best.\n",
    "# Only meaningful as a convergence check if that order is evaluation order (i.e. the\n",
    "# frame has not been sorted by loss); a flattening running-best curve suggests convergence.\n",
    "trial_loss = hpt_results['loss']\n",
    "ax = trial_loss.plot(label='trial loss', alpha=0.6)\n",
    "trial_loss.cummin().plot(ax=ax, label='best so far')\n",
    "ax.set(title='HPO convergence', xlabel='trial', ylabel='loss')\n",
    "ax.legend();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
