{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "import random\n",
    "import pickle\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "\n",
    "from helper import cal_metrics, f_get_minibatch\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "from import_data import *\n",
    "import scipy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import VIME, FSNet_multiBern"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "p         = 10\n",
    "blocksize = 10  #overall, p*blocksize\n",
    "\n",
    "sigma_n = 1.0\n",
    "seed    = 1234\n",
    "\n",
    "max_labeled_samples   = 10\n",
    "max_unlabeled_samples = 2000\n",
    "\n",
    "model_name = 'proposed_mvBern' \n",
    "\n",
    "DATASET_PATH = 'TWOMOON/ns_{}nu_{}'.format(int(2*max_labeled_samples), max_unlabeled_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "OUT_ITERATION = 100\n",
    "\n",
    "RESULTS  = np.zeros([OUT_ITERATION, 2])\n",
    "RESULTS2 = np.zeros([OUT_ITERATION, 2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_itr = 0\n",
    "\n",
    "# save_path was never assigned anywhere in this notebook, so a fresh-kernel run stopped here\n",
    "# with a NameError. Build it from the run identifiers of the configuration cell; the trailing\n",
    "# slash matters because later cells concatenate file names directly onto save_path.\n",
    "save_path = './{}/{}/itr{}/'.format(DATASET_PATH, model_name, out_itr)\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "\n",
    "# Each outer iteration consumes four consecutive seeds (seed .. seed+3), one per data split.\n",
    "seed    = 1234\n",
    "seed    = seed + out_itr * 4\n",
    "\n",
    "# Draw four independent 1000-sample two-moons sets; the calls differ only in their seed.\n",
    "block_noise = 0.3\n",
    "moons_cfg   = dict(n_samples=1000, n_feats=p, noise_twomoon=0.1, noise_nuisance=sigma_n)\n",
    "\n",
    "tr_X, tr_Y, tr_Y_onehot = get_noisy_two_moons(seed_=seed, **moons_cfg)\n",
    "UX, UY, UY_onehot       = get_noisy_two_moons(seed_=seed+1, **moons_cfg)\n",
    "va_X, va_Y, va_Y_onehot = get_noisy_two_moons(seed_=seed+2, **moons_cfg)\n",
    "te_X, te_Y, te_Y_onehot = get_noisy_two_moons(seed_=seed+3, **moons_cfg)\n",
    "\n",
    "# Pass every feature matrix through get_blockcorr, reusing the seed of its own split.\n",
    "tr_X = get_blockcorr(tr_X, blocksize, block_noise, seed)\n",
    "UX   = get_blockcorr(UX, blocksize, block_noise, seed+1)\n",
    "va_X = get_blockcorr(va_X, blocksize, block_noise, seed+2)\n",
    "te_X = get_blockcorr(te_X, blocksize, block_noise, seed+3)\n",
    "\n",
    "\n",
    "\n",
    "# Seed Python's RNG so the labeled subset is reproducible for this seed.\n",
    "random.seed(seed)\n",
    "# Draw max_labeled_samples row indices from each class (label 1, then label 0) without replacement.\n",
    "idx1 = random.sample(np.where(tr_Y==1)[0].tolist(), max_labeled_samples)\n",
    "idx0 = random.sample(np.where(tr_Y==0)[0].tolist(), max_labeled_samples)\n",
    "\n",
    "# Merge both classes and shuffle so the labeled rows are not ordered by class.\n",
    "idx  = idx1 + idx0\n",
    "random.shuffle(idx)\n",
    "\n",
    "# Keep only the selected rows: the labeled set now has 2*max_labeled_samples samples.\n",
    "tr_X        = tr_X[idx]\n",
    "tr_Y        = tr_Y[idx]\n",
    "tr_Y_onehot = tr_Y_onehot[idx]\n",
    "\n",
    "# Keep unscaled copies of every split (the *_org arrays are not read again in the cells shown here).\n",
    "tr_X_org = np.copy(tr_X)\n",
    "va_X_org = np.copy(va_X)\n",
    "te_X_org = np.copy(te_X)\n",
    "UX_org   = np.copy(UX)\n",
    "\n",
    "# Min-max statistics are fitted on the labeled rows plus the first unlabeled draw only;\n",
    "# validation and test rows are transformed with those statistics, never fitted on.\n",
    "scaler = MinMaxScaler()\n",
    "scaler.fit(np.concatenate([tr_X, UX], axis=0))\n",
    "\n",
    "\n",
    "tr_X    = scaler.transform(tr_X)\n",
    "va_X    = scaler.transform(va_X)\n",
    "te_X    = scaler.transform(te_X)\n",
    "\n",
    "if max_unlabeled_samples > 1000:\n",
    "    # One draw yields 1000 rows, so larger budgets get a second draw generated with seed-1.\n",
    "    UX_, UY_, UY_onehot_ = get_noisy_two_moons(n_samples=1000, n_feats=p, noise_twomoon=0.1, noise_nuisance=sigma_n, seed_=seed-1)\n",
    "    UX_                  = get_blockcorr(UX_, blocksize, block_noise, seed-1)\n",
    "\n",
    "    UX        = np.concatenate([UX, UX_], axis=0)\n",
    "    UY        = np.concatenate([UY, UY_])\n",
    "    UY_onehot = np.concatenate([UY_onehot, UY_onehot_], axis=0)\n",
    "\n",
    "# Enforce the budget in both branches: previously a budget between 1000 and 2000 silently\n",
    "# produced 2000 rows. Budgets above 2000 are still capped at the 2000 rows available.\n",
    "UX        = UX[:max_unlabeled_samples]\n",
    "UY        = UY[:max_unlabeled_samples]\n",
    "UY_onehot = UY_onehot[:max_unlabeled_samples]\n",
    "\n",
    "# Scale with the statistics fitted earlier (the second draw did not contribute to them).\n",
    "UX        = scaler.transform(UX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multivariate Bernoulli Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature-by-feature correlation matrix of the scaled unlabeled data.\n",
    "cov = np.corrcoef(UX.T)\n",
    "\n",
    "# Features listed here are left out of the correlated draw; none are excluded.\n",
    "remove_idx = []\n",
    "\n",
    "# Drop the excluded rows and columns, then take the lower Cholesky factor L of what remains.\n",
    "cov_new = np.delete(np.delete(cov, remove_idx, axis=0), remove_idx, axis=1)\n",
    "L       = scipy.linalg.cholesky(cov_new, lower=True)\n",
    "\n",
    "# Rebind the intermediate matrices to empty lists; only L is needed from here on.\n",
    "cov_new, cov = [], []\n",
    "\n",
    "\n",
    "def mask_generation(mb_size_, pi_):\n",
    "    '''\n",
    "        Sample mb_size_ binary masks whose entries are coupled through a Gaussian copula.\n",
    "\n",
    "        The latent Gaussian draw comes from copula_generation (one row per sample). Each entry\n",
    "        is mapped through the standard normal CDF\n",
    "            Phi(x; 0,1) = 1/2 * (1 + erf( x/sqrt(2) ))\n",
    "        and the mask entry is 1.0 where that value is below pi_, otherwise 0.0.\n",
    "\n",
    "        Returns a float array with one row per sample.\n",
    "    '''\n",
    "    latent  = copula_generation(mb_size_)\n",
    "    uniform = 1/2 * (1 + scipy.special.erf(latent/np.sqrt(2)))\n",
    "    return (uniform < pi_).astype(float)\n",
    "\n",
    "\n",
    "def copula_generation(mb_size_):\n",
    "    '''\n",
    "        Draw mb_size_ Gaussian vectors, returned with one row per sample.\n",
    "\n",
    "        Standard normal noise is multiplied by the module-level Cholesky factor L. When\n",
    "        remove_idx is non-empty, the coordinates listed there are filled with independent\n",
    "        standard normal noise and the remaining ones with the L-transformed draw.\n",
    "    '''\n",
    "    noise      = np.random.normal(loc=0., scale=1., size=[np.shape(L)[0], mb_size_])\n",
    "    if len(remove_idx) == 0:\n",
    "        return np.matmul(L, noise).T\n",
    "\n",
    "    # Same draw order as before: noise for the retained block first, then the excluded part.\n",
    "    free_noise = np.random.normal(loc=0., scale=1., size=[len(remove_idx), mb_size_])\n",
    "    kept_idx   = [j for j in range(x_dim) if j not in remove_idx]\n",
    "\n",
    "    samples    = np.zeros([x_dim, mb_size_])\n",
    "    samples[kept_idx, :]   = np.matmul(L, noise)\n",
    "    samples[remove_idx, :] = free_noise\n",
    "    return samples.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row counts of the labeled and unlabeled sets.\n",
    "num_labeled, num_unlabeled = tr_X.shape[0], UX.shape[0]\n",
    "num_all = num_labeled + num_unlabeled\n",
    "\n",
    "# Input width and one-hot label width.\n",
    "x_dim, y_dim = tr_X.shape[1], tr_Y_onehot.shape[1]\n",
    "\n",
    "y_type = 'categorical'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# STEP1: SELF-SUPERVISION PHASE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes for the self-supervised model; reg_scale is passed through as 'reg_scale'.\n",
    "reg_scale    = 0.\n",
    "num_layers_e = 3\n",
    "h_dim_e      = 100\n",
    "z_dim        = 10\n",
    "\n",
    "input_dims = dict(x_dim=x_dim, z_dim=z_dim)\n",
    "\n",
    "# The decoder entries reuse the encoder's width and depth.\n",
    "network_settings = dict(\n",
    "    h_dim_e=h_dim_e,\n",
    "    num_layers_e=num_layers_e,\n",
    "    h_dim_d=h_dim_e,\n",
    "    num_layers_d=num_layers_e,\n",
    "    fc_activate_fn=tf.nn.relu,\n",
    "    reg_scale=reg_scale,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the default graph before the VIME model is built into it.\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Let TensorFlow grow its GPU memory allocation on demand (allow_growth).\n",
    "# NOTE: an earlier comment here mentioned XLA, but no XLA option is set by this code.\n",
    "config = tf.ConfigProto()\n",
    "# config = tf.ConfigProto(device_count = {'GPU': 0})\n",
    "config.gpu_options.allow_growth = True\n",
    "\n",
    "# The session is handed to the model, which is constructed under the name \"VIME\".\n",
    "sess = tf.Session(config=config)\n",
    "model = VIME(sess, \"VIME\", input_dims, network_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reporting window (iterations) and the upper bound on training iterations.\n",
    "step_size      = 1000\n",
    "iteration      = 500000\n",
    "\n",
    "mb_size        = 32\n",
    "learning_rate  = 1e-4\n",
    "\n",
    "keep_prob      = 1.0\n",
    "alpha          = 10.0\n",
    "\n",
    "# NOTE: this rebinds p. The configuration cell used p as the feature count (n_feats); from\n",
    "# here on it is the value handed to mask_generation as pi_. Re-running the data-generation\n",
    "# cell after this one would therefore pass 0.5 as n_feats.\n",
    "p = 0.5\n",
    "\n",
    "# Feature-wise mean over ALL unlabeled rows, computed before the hold-out split below.\n",
    "x_mean = np.mean(UX, axis=0, keepdims=True)\n",
    "\n",
    "# Hold out the leading rows of UX for validation. The hold-out size used to be a fixed 1000,\n",
    "# which left an empty training pool whenever the unlabeled pool had 1000 rows or fewer; it is\n",
    "# now capped at half of the available rows (still 1000 for a 2000-row pool).\n",
    "# NOTE: UX is re-sliced under the same name, so running this cell twice shrinks it again.\n",
    "num_holdout = min(1000, np.shape(UX)[0] // 2)\n",
    "UX2 = UX[:num_holdout]\n",
    "UX  = UX[num_holdout:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise every variable of the current graph, then create the checkpoint saver.\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)\n",
    "\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "banner = '============================================='\n",
    "print(banner)\n",
    "print('Start Feature Selection .... OUT_ITR {}, ALPHA {}, #MB {}, KEEP PROB {}'.format(out_itr, alpha, mb_size, keep_prob))\n",
    "print(banner)\n",
    "\n",
    "# Windowed running averages (training / validation) of the total loss and its two\n",
    "# components, which the progress line labels loss_Rx and loss_Rm.\n",
    "avg_loss    = va_avg_loss    = 0\n",
    "avg_loss_r1 = va_avg_loss_r1 = 0\n",
    "avg_loss_r2 = va_avg_loss_r2 = 0\n",
    "\n",
    "# min_loss tracks the best windowed validation loss; max_acc is not read in the cells shown.\n",
    "max_acc, min_loss = 0., 1e+8\n",
    "\n",
    "# Early stopping: stop after max_flag consecutive windows without a new best loss.\n",
    "max_flag, stop_flag = 20, 0\n",
    "\n",
    "# Self-supervision loop: one training step and one validation evaluation per iteration,\n",
    "# with reporting, checkpointing and early stopping every step_size iterations.\n",
    "for itr in range(iteration):       \n",
    "    # Training minibatch from the unlabeled pool; the second return value is ignored.\n",
    "    x_mb, _       = f_get_minibatch(mb_size, UX, UX)\n",
    "    # x_mean repeated once per batch row; passed to the model as x_bar_.\n",
    "    x2_mb         = np.tile(x_mean, [np.shape(x_mb)[0], 1])\n",
    "\n",
    "    # One correlated binary mask per batch row, with p passed as pi_.\n",
    "    m_mb          = mask_generation(x_mb.shape[0], p)\n",
    "\n",
    "    _, tmp_loss, tmp_loss_r1, tmp_loss_r2  = model.train_main(\n",
    "        x_=x_mb, x_bar_=x2_mb, m_=m_mb, alpha_=alpha, lr_train_=learning_rate, k_prob_=keep_prob\n",
    "    )\n",
    "\n",
    "    # Dividing by step_size turns the per-window sums into averages.\n",
    "    avg_loss      += tmp_loss/step_size\n",
    "    avg_loss_r1   += tmp_loss_r1/step_size\n",
    "    avg_loss_r2   += tmp_loss_r2/step_size\n",
    "\n",
    "    \n",
    "    # Validation minibatch from the held-out rows UX2, with a freshly drawn mask.\n",
    "    x_mb, _       = f_get_minibatch(min(mb_size, np.shape(UX2)[0]), UX2, UX2)\n",
    "    x2_mb         = np.tile(x_mean, [np.shape(x_mb)[0], 1])\n",
    "\n",
    "    m_mb          = mask_generation(x_mb.shape[0], p)\n",
    "\n",
    "    # get_loss_main returns the same three losses; presumably without a weight update -- TODO confirm.\n",
    "    tmp_loss, tmp_loss_r1, tmp_loss_r2   = model.get_loss_main(x_=x_mb, x_bar_=x2_mb, m_=m_mb, alpha_=alpha)\n",
    "    va_avg_loss     += tmp_loss/step_size\n",
    "    va_avg_loss_r1  += tmp_loss_r1/step_size\n",
    "    va_avg_loss_r2  += tmp_loss_r2/step_size\n",
    "            \n",
    "    if (itr+1)%step_size == 0:\n",
    "        # Count this window as a non-improvement; reset below if the loss did improve.\n",
    "        stop_flag += 1\n",
    "        \n",
    "        print(\"ITR {:05d}  | TR: loss={:.3f} loss_Rx={:.3f} loss_Rm={:.3f}  | VA: loss={:.3f} loss_Rx={:.3f} loss_Rm={:.3f}\".format(\n",
    "            itr+1, avg_loss, avg_loss_r1, avg_loss_r2, va_avg_loss, va_avg_loss_r1, va_avg_loss_r2\n",
    "        ))\n",
    "        \n",
    "        \n",
    "        # New best validation loss: write a checkpoint and export the encoder variables to\n",
    "        # an .npz archive (stored positionally), which the supervision phase loads later.\n",
    "        if va_avg_loss < min_loss:\n",
    "            print('saved...')\n",
    "            saver.save(sess, save_path + 'vime_trained')\n",
    "            np.savez(save_path + 'vime_encoder.npz', *sess.run(model.vars_encoder))\n",
    "            \n",
    "            min_loss  = va_avg_loss\n",
    "            \n",
    "            stop_flag = 0\n",
    "            \n",
    "        # Start a fresh averaging window.\n",
    "        avg_loss     = 0\n",
    "        va_avg_loss  = 0 \n",
    "\n",
    "        avg_loss_r1     = 0\n",
    "        va_avg_loss_r1  = 0 \n",
    "        \n",
    "        avg_loss_r2     = 0\n",
    "        va_avg_loss_r2  = 0 \n",
    "        \n",
    "        # Stop once max_flag consecutive windows brought no improvement.\n",
    "        if stop_flag >= max_flag:\n",
    "            break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# STEP2: SUPERVISION PHASE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predictor sizes for the supervised model; the encoder sizes are reused from the earlier cell.\n",
    "reg_scale    = 0.\n",
    "num_layers_p = 1\n",
    "h_dim_p      = 100\n",
    "\n",
    "input_dims = dict(x_dim=x_dim, z_dim=z_dim, y_dim=y_dim, y_type=y_type)\n",
    "\n",
    "network_settings = dict(\n",
    "    h_dim_e=h_dim_e,\n",
    "    num_layers_e=num_layers_e,\n",
    "    h_dim_p=h_dim_p,\n",
    "    num_layers_p=num_layers_p,\n",
    "    fc_activate_fn_e=tf.nn.relu,\n",
    "    fc_activate_fn_p=tf.nn.relu,\n",
    "    reg_scale=reg_scale,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The self-supervision session is no longer needed (the encoder weights it produced are\n",
    "# loaded from disk below). It used to be rebound and left open, keeping its resources\n",
    "# allocated; close it before the graph is rebuilt.\n",
    "sess.close()\n",
    "tf.reset_default_graph()\n",
    "\n",
    "# Let TensorFlow grow its GPU memory allocation on demand (allow_growth). No XLA option is\n",
    "# set here, contrary to what the earlier comment suggested.\n",
    "config = tf.ConfigProto()\n",
    "# config = tf.ConfigProto(device_count = {'GPU': 0})\n",
    "config.gpu_options.allow_growth = True\n",
    "\n",
    "# Fresh session for the supervised model, constructed under the name \"FSNet\".\n",
    "sess = tf.Session(config=config)\n",
    "model = FSNet_multiBern(sess, \"FSNet\", input_dims, network_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "step_size      = 1000\n",
    "iteration      = 100000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Never request more rows per minibatch than there are labeled samples.\n",
    "mb_size        = min(32, np.shape(tr_X)[0])\n",
    "\n",
    "learning_rate  = 1e-4\n",
    "\n",
    "# keep_prob is passed as k_prob_ and lmbda as lmbda_ in the training loop.\n",
    "keep_prob, lmbda = 1.0, 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise every variable of the current graph, then create the checkpoint saver.\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)\n",
    "\n",
    "saver = tf.train.Saver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the encoder weights exported by the self-supervision phase into this model's encoder.\n",
    "# np.savez stored them positionally (arr_0, arr_1, ...), so archive order lines up with\n",
    "# model.vars_encoder. The archive is opened in a with-block so its file handle is closed\n",
    "# afterwards, and the key list is read once instead of being rebuilt on every iteration.\n",
    "# NOTE(review): allow_pickle=True is kept for compatibility; that is only acceptable because\n",
    "# the file is written by this notebook -- never point it at an untrusted archive.\n",
    "with np.load(save_path + 'vime_encoder.npz', allow_pickle=True) as pretrained_encoder:\n",
    "    for i, key in enumerate(pretrained_encoder.files):\n",
    "        sess.run(tf.assign(model.vars_encoder[i], pretrained_encoder[key]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: from here on the 'validation' arrays are copies of the labeled training set, so\n",
    "# the AUC values reported by the loop below are measured on training samples; the\n",
    "# separately generated validation split is discarded by this rebinding.\n",
    "va_X, va_Y_onehot = np.copy(tr_X), np.copy(tr_Y_onehot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "banner = '============================================='\n",
    "print(banner)\n",
    "print('Start Feature Selection .... OUT_ITR {}, LAMBDA {}, KEEP PROB {}'.format(out_itr, lmbda, keep_prob))\n",
    "print(banner)\n",
    "\n",
    "# Windowed running averages of the total loss and of the term reported as loss_m0.\n",
    "avg_loss    = avg_loss_m0    = 0.\n",
    "va_avg_loss = va_avg_loss_m0 = 0.\n",
    "\n",
    "# The names below are initialised here but not read by the loop in this cell.\n",
    "max_acc, min_loss   = 0., 1e+8\n",
    "max_flag, stop_flag = 100, 0\n",
    "num_selected_curr = num_selected_prev = 0\n",
    "\n",
    "\n",
    "# Supervision loop: one fine-tuning step per iteration, with a progress report on the\n",
    "# va_X / va_Y_onehot arrays every step_size iterations.\n",
    "for itr in range(iteration):\n",
    "    x_mb, y_mb     = f_get_minibatch(min(mb_size, np.shape(tr_X)[0]), tr_X, tr_Y_onehot)\n",
    "    x2_mb          = np.tile(x_mean, [np.shape(x_mb)[0], 1])\n",
    "    # Size the copula draw from the batch actually returned (as the self-supervision loop does\n",
    "    # for its masks) instead of from the requested mb_size, so the row counts always agree.\n",
    "    q_mb           = copula_generation(np.shape(x_mb)[0])\n",
    "\n",
    "    _, tmp_loss, tmp_loss_m0  = model.train_finetune(x_=x_mb, x_bar_=x2_mb, y_=y_mb, q_=q_mb, lmbda_=lmbda, lr_train_=learning_rate, k_prob_=keep_prob)\n",
    "    avg_loss      += tmp_loss/step_size\n",
    "    avg_loss_m0   += tmp_loss_m0/step_size\n",
    "\n",
    "    # NOTE: the 'VA' losses are evaluated on the same training minibatch, right after the\n",
    "    # train_finetune call; no separate validation batch is drawn here.\n",
    "    tmp_loss, tmp_loss_m0     = model.get_loss(x_=x_mb, x_bar_=x2_mb, y_=y_mb, q_=q_mb, lmbda_=lmbda)\n",
    "\n",
    "    va_avg_loss      += tmp_loss/step_size\n",
    "    va_avg_loss_m0   += tmp_loss_m0/step_size\n",
    "\n",
    "    if (itr+1)%step_size == 0:\n",
    "        # Hard selection mask: 1.0 where the current model.pi exceeds 0.5.\n",
    "        tmp_mask = (sess.run(model.pi) > 0.5).astype(float)\n",
    "        # x_mean repeated once per evaluation row; built once and shared by both predictions.\n",
    "        va_x_bar = np.tile(x_mean, [np.shape(va_X)[0], 1])\n",
    "        q_mb     = copula_generation(np.shape(va_X)[0])\n",
    "\n",
    "        # predict is fed a copula draw, predict_final the thresholded mask.\n",
    "        tmp_y    = model.predict(x_=va_X, x_bar_=va_x_bar, q_=q_mb)\n",
    "        tmp_y2   = model.predict_final(x_=va_X, x_bar_=va_x_bar, m_=tmp_mask)\n",
    "\n",
    "        va_auc, va_apc   = cal_metrics(va_Y_onehot, tmp_y)\n",
    "        va_auc2, va_apc2 = cal_metrics(va_Y_onehot, tmp_y2)\n",
    "\n",
    "        print(\"ITR {:05d}  | TR: loss={:.3f} loss_m0={:.3f}  | VA: loss={:.3f} loss_m0={:.3f} AUC:{:.3f}, AUC_Selected:{:.3f}\".format(\n",
    "            itr+1, avg_loss, avg_loss_m0, va_avg_loss, va_avg_loss_m0, va_auc, va_auc2\n",
    "        ))\n",
    "\n",
    "        # Start a fresh averaging window.\n",
    "        avg_loss      = 0.\n",
    "        avg_loss_m0   = 0.\n",
    "\n",
    "        va_avg_loss      = 0.\n",
    "        va_avg_loss_m0   = 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the current value of model.pi from the session.\n",
    "feature_importance = sess.run(model.pi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the value stored in feature_importance as the cell output.\n",
    "feature_importance"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
