{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3acfa825",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only GPU 0 to CUDA. This is set here, ahead of the TensorFlow import in the next cell.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa62b4b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "from tensorflow_addons.activations import sparsemax\n",
    "from tensorflow.keras import initializers\n",
    "from tensorflow.keras.layers import Dense, Input, Concatenate, Multiply\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.callbacks import EarlyStopping, Callback\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras.optimizers.legacy import Adam, Adagrad\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f34ca826",
   "metadata": {},
   "outputs": [],
   "source": [
    "import PGNN_source as pg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a443c79b",
   "metadata": {},
   "source": [
    "# Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bf5a52c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_model(X, v):\n",
    "    \"\"\"Return the mean ``mu`` for every row of ``X``.\n",
    "\n",
    "    The log-mean is the sum of four pieces, added in this order:\n",
    "      * 0.2 times the row sum of cos(X) over columns 0-5,\n",
    "      * the row sum of 0.2 / (X**2 + 1) over columns 6-9,\n",
    "      * the constant 0.2,\n",
    "      * the offset ``v``.\n",
    "\n",
    "    X : 2-D array; only columns 0-9 are read.\n",
    "    v : offset on the log scale (the caller passes ``v_rep``, one value per row).\n",
    "    \"\"\"\n",
    "    cos_part = 0.2 * np.cos(X[:, :6]).sum(axis=1)\n",
    "    rational_part = (0.2 / (X[:, 6:10] ** 2 + 1)).sum(axis=1)\n",
    "    log_mu = cos_part + rational_part + 0.2 + v\n",
    "    return np.exp(log_mu)\n",
    "\n",
    "def generate_data(data_type, dir_name, p = 100, n_simul = 100):\n",
    "    \"\"\"Simulate ``n_simul`` Poisson data sets with gamma random effects and write each to CSV.\n",
    "\n",
    "    data_type : str of the form 'n_sub-n_num-lam-rand_dist', e.g. '10000-20-0.5-gamma'.\n",
    "        n_sub subjects with n_num rows each; random effects are drawn with\n",
    "        np.random.gamma(shape=1/lam, scale=lam). The rand_dist field is parsed but\n",
    "        not used here: the draws are always gamma.\n",
    "    dir_name : str prepended verbatim to the file name (include the trailing '/').\n",
    "    p : number of N(0, 1) covariate columns x0..x{p-1}.\n",
    "    n_simul : number of replicates; replicate r is generated after np.random.seed(r).\n",
    "\n",
    "    Each replicate is written to dir_name + 'simul-data-' + data_type + '-' + str(r) + '.csv'\n",
    "    with the covariate columns followed by y, u, mu, sub and num.\n",
    "    \"\"\"\n",
    "    n_sub, n_num, lam, _rand_dist = data_type.split('-')\n",
    "    n_sub, n_num, lam = int(n_sub), int(n_num), float(lam)\n",
    "    N = n_sub * n_num\n",
    "    # DataFrame.to_csv does not create missing directories, so make sure the target exists.\n",
    "    out_dir = os.path.dirname(dir_name + 'simul-data-')\n",
    "    if out_dir:\n",
    "        os.makedirs(out_dir, exist_ok=True)\n",
    "    # These do not depend on the replicate, so build them once.\n",
    "    covariate_names = ['x'+str(i) for i in range(p)]\n",
    "    sub_index = np.repeat(np.arange(n_sub), n_num)\n",
    "    num_index = np.tile(np.arange(n_num), n_sub)\n",
    "    for repeat in range(n_simul):\n",
    "        # generate data (the order of the random draws is kept as-is so each seed reproduces the same file)\n",
    "        np.random.seed(repeat)\n",
    "        X = np.random.normal(0, 1, size=(N,p))\n",
    "        u = np.random.gamma(1/lam, lam, n_sub)\n",
    "        u_rep = np.repeat(u, n_num)\n",
    "        v_rep = np.log(u_rep)\n",
    "        mu = mean_model(X, v_rep)\n",
    "        y = np.random.poisson(mu)\n",
    "        # save data\n",
    "        data = pd.DataFrame(X, columns=covariate_names)\n",
    "        data['y'] = y\n",
    "        data['u'] = u_rep\n",
    "        data['mu'] = mu\n",
    "        data['sub'] = sub_index\n",
    "        data['num'] = num_index\n",
    "        file_name = dir_name + 'simul-data-' + data_type + '-' + str(repeat)\n",
    "        data.to_csv(file_name+'.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd4b5514",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the simulated data sets under ./simulation_data/ (created if it does not exist yet).\n",
    "dir_name = os.getcwd()+'/simulation_data/'\n",
    "os.makedirs(dir_name, exist_ok=True)\n",
    "data_type_list = ['10000-20-0.5-gamma']\n",
    "for data_type in data_type_list:\n",
    "    # generate_data is the generator defined in the cell above.\n",
    "    generate_data(data_type, dir_name, p = 100, n_simul = 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae4d94ae",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e53c81c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Data location and data set selection ---\n",
    "cwd = os.getcwd()\n",
    "dir_name = cwd + '/simulation_data/'\n",
    "data_type = '10000-20-0.5-gamma'  # parsed later as n_sub-n_num-lam-rand_dist\n",
    "data_type_list = [data_type]\n",
    "\n",
    "# --- Simulation and optimisation settings ---\n",
    "n_simul = 50\n",
    "phi_init = 0.8\n",
    "lam_init = 0.8\n",
    "lr = 0.01\n",
    "optimizer = Adam(learning_rate=lr)\n",
    "patience = 20\n",
    "pretrain = 100\n",
    "moments_epochs = 100\n",
    "max_epochs = 1000\n",
    "\n",
    "# --- Network architecture ---\n",
    "p = 100  # number of input features\n",
    "nodes = [30, 10, 10]  # widths of the hidden Dense layers, in order\n",
    "activation = 'leaky_relu'\n",
    "batch_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "412f22df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the data tag 'n_sub-n_num-lam-rand_dist' and derive the per-subject split sizes.\n",
    "n_sub, n_num, lam, rand_dist = data_type.split('-')\n",
    "n_sub, n_num, lam = int(n_sub), int(n_num), float(lam)\n",
    "n_num_train, n_num_valid, n_num_test = int(n_num*0.6), int(n_num*0.2), int(n_num*0.2)\n",
    "N_train, N_valid, N_test = n_sub*n_num_train, n_sub*n_num_valid, n_sub*n_num_test\n",
    "feature_cols = ['x'+str(i) for i in range(p)]\n",
    "\n",
    "\n",
    "def _subset_arrays(frame):\n",
    "    \"\"\"Convert one data subset into the (X, y, z, Z) arrays used by the training loop.\n",
    "\n",
    "    X : float32 array of the covariate columns x0..x{p-1}.\n",
    "    y : float32 array of the 'y' column.\n",
    "    z : int32 array of the 'sub' column (subject ids).\n",
    "    Z : float32 one-hot encoding of z with n_sub columns.\n",
    "    \"\"\"\n",
    "    X = np.array(frame[feature_cols], dtype=np.float32)\n",
    "    y = np.array(frame['y'], dtype=np.float32)\n",
    "    z = np.array(frame['sub'].astype('int32'))\n",
    "    # Fill the one-hot matrix directly in float32. The values equal\n",
    "    # np.eye(n_sub)[z].astype('float32') without building the float64 intermediates.\n",
    "    Z = np.zeros((len(z), n_sub), dtype=np.float32)\n",
    "    Z[np.arange(len(z)), z] = 1.0\n",
    "    return X, y, z, Z\n",
    "\n",
    "\n",
    "# One row per input feature (each attention layer has p outputs), one column per replicate.\n",
    "PGNN_attention_1 = np.zeros((p, n_simul))\n",
    "PGNN_attention_2 = np.zeros((p, n_simul))\n",
    "PGNN_attention_3 = np.zeros((p, n_simul))\n",
    "# Iterate over exactly n_simul replicates: the result arrays only have n_simul columns,\n",
    "# so a longer loop would raise IndexError on the first out-of-range column.\n",
    "for simul_num in tqdm(range(n_simul)):\n",
    "\n",
    "    file_name = dir_name + 'simul-data-' + data_type + '-' + str(simul_num)\n",
    "    data = pd.read_csv(file_name+'.csv')\n",
    "    # Split within each subject on the 'num' index: [0, n_num_train) is train,\n",
    "    # [n_num_train, n_num - n_num_test) is validation, everything else is test.\n",
    "    data_train = data[data['num'].isin(range(n_num_train))]\n",
    "    data_valid = data[data['num'].isin(range(n_num_train, n_num - n_num_test))]\n",
    "    # '~' is the boolean negation operator for a mask (unary '-' is not).\n",
    "    data_test = data[~data['num'].isin(range(n_num - n_num_test))]\n",
    "    X_train, y_train, z_train, Z_train = _subset_arrays(data_train)\n",
    "    X_valid, y_valid, z_valid, Z_valid = _subset_arrays(data_valid)\n",
    "    X_test, y_test, z_test, Z_test = _subset_arrays(data_test)\n",
    "    batch_ratio = batch_size/N_train\n",
    "    pg.seed_everything()\n",
    "    train_batch = tf.data.Dataset.from_tensor_slices((X_train, Z_train, y_train)).shuffle(N_train).batch(batch_size)\n",
    "\n",
    "    # Build a fresh network for this replicate.\n",
    "    K.clear_session(); pg.seed_everything()\n",
    "    n_features = np.shape(X_train)[1]\n",
    "    input_X = Input(shape=(n_features,), dtype='float32')\n",
    "    input_Z = Input(shape=(np.shape(Z_train)[1],), dtype='float32')\n",
    "    # Three sparsemax attention heads over the inputs. Each head also gets its own\n",
    "    # Model so that its scores can be read out after training.\n",
    "    attention_score_1 = Dense(n_features, activation=sparsemax)(input_X)\n",
    "    attention_score_2 = Dense(n_features, activation=sparsemax)(input_X)\n",
    "    attention_score_3 = Dense(n_features, activation=sparsemax)(input_X)\n",
    "    attention_model_1 = Model(input_X, attention_score_1)\n",
    "    attention_model_2 = Model(input_X, attention_score_2)\n",
    "    attention_model_3 = Model(input_X, attention_score_3)\n",
    "    # Re-weight the inputs by each head and feed the concatenation to the hidden layers.\n",
    "    m1 = Multiply()([input_X, attention_score_1])\n",
    "    m2 = Multiply()([input_X, attention_score_2])\n",
    "    m3 = Multiply()([input_X, attention_score_3])\n",
    "    m = Concatenate()([m1, m2, m3])\n",
    "    for width in nodes:\n",
    "        m = Dense(width, activation=activation)(m)\n",
    "    xb = Dense(1, activation='linear')(m)\n",
    "    zv = Dense(1, activation='linear', use_bias=False)(input_Z)\n",
    "    PGNN_model = Model(inputs=[input_X, input_Z], outputs=[xb, zv])\n",
    "\n",
    "    # NOTE(review): `optimizer` is the single instance created in the configuration cell and is\n",
    "    # reused for every replicate -- confirm that carrying optimizer state across models is intended.\n",
    "    res = pg.train_model(PGNN_model, train_batch, [X_train, Z_train, y_train], [X_valid, Z_valid, y_valid],\n",
    "         pg.pg_hlik_loss, optimizer, lam_init, batch_ratio, patience, pretrain, max_epochs, moments_epochs)\n",
    "\n",
    "    # Store each head's attention scores averaged over the training rows for this replicate.\n",
    "    PGNN_attention_1[:,simul_num] = np.mean(attention_model_1([X_train]), axis=0)\n",
    "    PGNN_attention_2[:,simul_num] = np.mean(attention_model_2([X_train]), axis=0)\n",
    "    PGNN_attention_3[:,simul_num] = np.mean(attention_model_3([X_train]), axis=0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dc3fd5a",
   "metadata": {},
   "source": [
    "# Figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71429de1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68703eec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average each head's per-feature scores over the simulation axis (columns),\n",
    "# then take the plain mean of the three heads.\n",
    "PGNN_attention_average_1 = PGNN_attention_1.mean(axis=1)\n",
    "PGNN_attention_average_2 = PGNN_attention_2.mean(axis=1)\n",
    "PGNN_attention_average_3 = PGNN_attention_3.mean(axis=1)\n",
    "mean_attention = (PGNN_attention_average_1 + PGNN_attention_average_2 + PGNN_attention_average_3)/3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6ecb358",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set(rc={'figure.figsize':(18,6)})\n",
    "sns.set_style('whitegrid', {'grid.linestyle': '--'})\n",
    "# Wide-form frame with a single row: one column (and therefore one bar) per input feature.\n",
    "attention_frame = pd.DataFrame(mean_attention.reshape(1,-1))\n",
    "# barplot returns a matplotlib Axes; pass the frame as `data=` so the call does not\n",
    "# depend on the position of the data argument.\n",
    "ax = sns.barplot(data=attention_frame)\n",
    "ax.set_ylim([0, 0.07])\n",
    "ax.set_title('Feature Importance', fontsize=20)\n",
    "ax.set_xlabel('Input features', fontsize=16)\n",
    "ax.set_ylabel('Average attention score', fontsize=16)\n",
    "for tick_label in ax.get_xticklabels():\n",
    "    tick_label.set_rotation(70)\n",
    "ax.figure.savefig('attention.pdf', dpi=300, bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
