{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a425744",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import psutil\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b77a04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_addons as tfa\n",
    "from tensorflow_addons.activations import sparsemax\n",
    "from tensorflow.keras import initializers\n",
    "from tensorflow.keras.layers import Dense, Input, Concatenate, Multiply\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.callbacks import EarlyStopping, Callback\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras.optimizers.legacy import Adam, Adagrad\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e66a7ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import PGNN_source as pg"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8fe95f5",
   "metadata": {},
   "source": [
    "# Data generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96b6ea32",
   "metadata": {},
   "outputs": [],
   "source": [
    "from statsmodels.tsa.arima_process import ArmaProcess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "903f06c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean_model(X, v):\n",
    "    \"\"\"Conditional Poisson mean given covariates and log random effects.\n",
    "\n",
    "    Evaluates, row by row,\n",
    "    exp(0.2 * sum_{j<3} cos(x_j) + sum_{j in {3,4}} 0.2 / (x_j**2 + 1) + 0.2 + v).\n",
    "    Only the first five columns of X are used.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D array with one row per observation.\n",
    "        v: Log random effect for every row of X (v_rep, of dimension N).\n",
    "\n",
    "    Returns:\n",
    "        1-D array holding one mean per row of X.\n",
    "    \"\"\"\n",
    "    cosine_part = 0.2 * np.cos(X[:, :3]).sum(axis=1)\n",
    "    rational_part = (0.2 / (X[:, 3:5] ** 2 + 1)).sum(axis=1)\n",
    "    linear_predictor = cosine_part + rational_part + 0.2 + v\n",
    "    return np.exp(linear_predictor)\n",
    "\n",
    "def generate_data(data_type, dir_name, p=10, n_simul=100):\n",
    "    \"\"\"Simulate clustered Poisson counts with gamma random effects and save them as CSV.\n",
    "\n",
    "    For every replicate ``r`` in ``range(n_simul)`` the NumPy global seed is set to ``r``,\n",
    "    ``p`` AR(1) covariates are drawn for each of ``n_sub * n_num`` rows, one random effect\n",
    "    ``u ~ Gamma(shape=1/lam, scale=lam)`` is drawn per subject, and ``y ~ Poisson(mu)`` with\n",
    "    ``mu = mean_model(X, log(u))``. Each replicate is written to\n",
    "    ``dir_name + 'simul-data-' + data_type + '-' + str(r) + '.csv'``.\n",
    "\n",
    "    Args:\n",
    "        data_type: ``'<n_sub>-<n_num>-<lam>-<rand_dist>'``, e.g. ``'200-20-0.5-gamma'``.\n",
    "            ``rand_dist`` is parsed but not used; the random effects are always gamma.\n",
    "        dir_name: Output prefix, concatenated as-is with the file name (so a directory\n",
    "            must end with a path separator). Its directory part is created if missing.\n",
    "        p: Number of covariate columns ``x0 .. x{p-1}``.\n",
    "        n_simul: Number of replicates (files) to generate.\n",
    "    \"\"\"\n",
    "    n_sub, n_num, lam, _rand_dist = data_type.split('-')\n",
    "    n_sub, n_num, lam = int(n_sub), int(n_num), float(lam)\n",
    "    N = n_sub * n_num\n",
    "    # Create the output location up front so to_csv cannot fail on a missing folder.\n",
    "    os.makedirs(os.path.dirname(dir_name) or '.', exist_ok=True)\n",
    "    arma = ArmaProcess([1, -0.5], 1) # AR(1) with coeff 0.5\n",
    "    # Columns that do not depend on the replicate are built once.\n",
    "    x_names = ['x' + str(i) for i in range(p)]\n",
    "    sub = np.repeat(np.arange(n_sub), n_num)\n",
    "    num = np.tile(np.arange(n_num), n_sub)\n",
    "    for repeat in range(n_simul):\n",
    "        # The replicate index doubles as the seed; the order of the random draws\n",
    "        # below (X, then u, then y) determines the generated values.\n",
    "        np.random.seed(repeat)\n",
    "        X = arma.generate_sample(nsample=(N,p), axis=1)\n",
    "        u = np.random.gamma(1/lam, lam, n_sub)\n",
    "        u_rep = np.repeat(u, n_num)\n",
    "        v_rep = np.log(u_rep)\n",
    "        # mean_model takes exactly (X, v); the former extra b0/c0/c1 arguments were\n",
    "        # undefined names and made this call fail.\n",
    "        mu = mean_model(X, v_rep)\n",
    "        y = np.random.poisson(mu)\n",
    "        # save data\n",
    "        data = pd.DataFrame(X, columns=x_names)\n",
    "        data['y'] = y\n",
    "        data['u'] = u_rep\n",
    "        data['mu'] = mu\n",
    "        data['sub'] = sub\n",
    "        data['num'] = num\n",
    "        file_name = dir_name + 'simul-data-' + data_type + '-' + str(repeat)\n",
    "        data.to_csv(file_name+'.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d2ba928",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write all simulated data sets below <current working directory>/simulation_data_consistency/.\n",
    "cwd = os.getcwd()\n",
    "dir_name = cwd + '/simulation_data_consistency/'\n",
    "# Each entry is parsed by generate_data as '<n_sub>-<n_num>-<lam>-<rand_dist>'.\n",
    "data_type_list = [\n",
    "    '200-20-0.5-gamma',\n",
    "    '500-50-0.5-gamma',\n",
    "    '1000-100-0.5-gamma',\n",
    "]\n",
    "for data_type in data_type_list:\n",
    "    generate_data(data_type, dir_name, p=10, n_simul=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "873d1c2c",
   "metadata": {},
   "source": [
    "# Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19048694",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location and layouts of the simulated data sets read by the simulation loop.\n",
    "cwd = os.getcwd()\n",
    "dir_name = cwd + '/simulation_data_consistency/'\n",
    "data_type_list = [\n",
    "    '200-20-0.5-gamma',\n",
    "    '500-50-0.5-gamma',\n",
    "    '1000-100-0.5-gamma',\n",
    "]\n",
    "n_simul = 100\n",
    "\n",
    "# lam_init is handed to pg.train_model; phi_init is not referenced by the cells of this notebook.\n",
    "phi_init = 0.8\n",
    "lam_init = 0.8\n",
    "\n",
    "# Optimisation settings forwarded to pg.train_model.\n",
    "lr = 0.005\n",
    "optimizer = Adam(learning_rate=lr)\n",
    "patience = 50\n",
    "pretrain = 50\n",
    "moments_epochs = 50\n",
    "max_epochs = 500\n",
    "\n",
    "# Widths of the hidden Dense layers and their activation.\n",
    "nodes = [10, 10, 10]\n",
    "activation = 'leaky_relu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ee0796",
   "metadata": {},
   "outputs": [],
   "source": [
    "# lambda_estimate[r, c] holds the value res['lam'] returned by pg.train_model for\n",
    "# replicate r of data_type_list[c].\n",
    "lambda_estimate = np.zeros((n_simul, len(data_type_list)))\n",
    "# Only the first five covariates are fed to the network.\n",
    "feature_cols = ['x' + str(i) for i in range(5)]\n",
    "\n",
    "\n",
    "def build_arrays(frame, n_sub):\n",
    "    \"\"\"Convert one subset of a simulated data set into float32 model inputs.\n",
    "\n",
    "    Args:\n",
    "        frame: DataFrame with the columns in feature_cols plus 'y' and 'sub'.\n",
    "        n_sub: Number of subjects, i.e. the width of the one-hot matrix.\n",
    "\n",
    "    Returns:\n",
    "        (X, Z, y): covariates, one-hot subject indicators with n_sub columns, responses.\n",
    "    \"\"\"\n",
    "    X = np.array(frame[feature_cols], dtype=np.float32)\n",
    "    y = np.array(frame['y'], dtype=np.float32)\n",
    "    z = np.array(frame['sub'].astype('int32'))\n",
    "    # Build the identity directly in float32 to avoid a float64 temporary of the same shape.\n",
    "    Z = np.eye(n_sub, dtype=np.float32)[z]\n",
    "    return X, Z, y\n",
    "\n",
    "\n",
    "for colnum, data_type in enumerate(data_type_list):\n",
    "\n",
    "    print(data_type)\n",
    "    n_sub, n_num, lam, rand_dist = data_type.split('-')\n",
    "    n_sub, n_num, lam = int(n_sub), int(n_num), float(lam)\n",
    "    n_num_train, n_num_valid, n_num_test = int(n_num*0.6), int(n_num*0.2), int(n_num*0.2)\n",
    "    N_train, N_valid, N_test = (n_sub * k for k in (n_num_train, n_num_valid, n_num_test))\n",
    "\n",
    "    for simul_num in tqdm(range(n_simul)):\n",
    "\n",
    "        file_name = dir_name + 'simul-data-' + data_type + '-' + str(simul_num)\n",
    "        data = pd.read_csv(file_name+'.csv')\n",
    "        # Split on the within-subject index 'num': first block for training, middle block\n",
    "        # for validation, remaining rows for testing.\n",
    "        is_train = data['num'].isin(range(n_num_train))\n",
    "        is_valid = data['num'].isin(range(n_num_train, n_num - n_num_test))\n",
    "        is_test = ~data['num'].isin(range(n_num - n_num_test))\n",
    "        X_train, Z_train, y_train = build_arrays(data[is_train], n_sub)\n",
    "        X_valid, Z_valid, y_valid = build_arrays(data[is_valid], n_sub)\n",
    "        X_test, Z_test, y_test = build_arrays(data[is_test], n_sub)\n",
    "\n",
    "        # One batch contains the whole training set.\n",
    "        batch_size, batch_ratio = N_train, 1.\n",
    "        pg.seed_everything()\n",
    "        train_batch = tf.data.Dataset.from_tensor_slices((X_train, Z_train, y_train)).shuffle(N_train).batch(batch_size)\n",
    "\n",
    "        K.clear_session(); pg.seed_everything()\n",
    "        input_X = Input(shape=(np.shape(X_train)[1],), dtype='float32')\n",
    "        input_Z = Input(shape=(np.shape(Z_train)[1],), dtype='float32')\n",
    "        m  = Dense(nodes[0], activation=activation)(input_X)\n",
    "        for i in range(1,len(nodes)):\n",
    "            m  = Dense(nodes[i], activation=activation)(m)\n",
    "        xb = Dense(1, activation='linear')(m)\n",
    "        zv = Dense(1, activation='linear', use_bias=False)(input_Z)\n",
    "        # Named mean_net so the mean_model() function used for data generation is not\n",
    "        # overwritten by a Keras Model when this cell runs.\n",
    "        mean_net = Model(inputs=[input_X, input_Z], outputs=[xb, zv])\n",
    "\n",
    "        # NOTE(review): the same `optimizer` instance is reused for every replicate;\n",
    "        # confirm pg.train_model does not depend on fresh optimizer state.\n",
    "        res = pg.train_model(mean_net, train_batch, [X_train, Z_train, y_train], [X_valid, Z_valid, y_valid],\n",
    "             pg.pg_hlik_loss, optimizer, lam_init, batch_ratio, patience, pretrain, max_epochs, moments_epochs)\n",
    "        lambda_estimate[simul_num, colnum] = res['lam']\n",
    "\n",
    "print(np.round(np.mean(lambda_estimate, axis=0),3))\n",
    "print(np.round(np.var(lambda_estimate, axis=0),3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c56605ee",
   "metadata": {},
   "source": [
    "# Figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abffc5df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One box per column of lambda_estimate, i.e. per entry of data_type_list.\n",
    "labels = ['n=200, q=20','n=500, q=50','n=1000, q=100']\n",
    "idx = range(1, len(labels) + 1)\n",
    "fig, ax = plt.subplots()\n",
    "ax.boxplot(lambda_estimate)\n",
    "# Dashed reference line at 0.5, the lambda encoded in every data_type string.\n",
    "ax.axhline(0.5, color='red', linewidth=0.5, linestyle='--')\n",
    "ax.set_xticks(list(idx))\n",
    "ax.set_xticklabels(labels)\n",
    "# Raw string: '\\l' is not a valid Python escape sequence and warns in a normal literal.\n",
    "ax.set_title(r'Estimates for variance component $\\lambda$')\n",
    "fig.savefig('boxplot.pdf', dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
