{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c15f99f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Expose only GPU 0 to TensorFlow; set here, before TensorFlow is imported in the next cell.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bf08179",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "from datetime import datetime\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import initializers\n",
    "from tensorflow.keras.layers import Dense, Input, Concatenate\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.callbacks import EarlyStopping, Callback\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras.optimizers.legacy import Adam, Adagrad\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "692f448f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import PGNN_source as pg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ce2cf67",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _hidden_stack(inputs, nodes, activation):\n",
    "    \"\"\"Apply Dense(width, activation=activation) for each width in `nodes`, in order.\n",
    "\n",
    "    Returns `inputs` unchanged when `nodes` is empty (no hidden layers).\n",
    "    \"\"\"\n",
    "    h = inputs\n",
    "    for width in nodes:\n",
    "        h = Dense(width, activation=activation)(h)\n",
    "    return h\n",
    "\n",
    "\n",
    "def make_mean_model_PF(nodes, activation, n_features=None, n_groups=None):\n",
    "    \"\"\"Build the PF-NN mean model with output exp(f(X)) * exp(Z v).\n",
    "\n",
    "    f is the Dense stack given by `nodes` followed by a 1-unit layer with bias; Z v is a\n",
    "    bias-free 1-unit layer on Z, so each Z column gets one coefficient. n_features and\n",
    "    n_groups default to the column counts of the global X_train / Z_train (the previous,\n",
    "    only behaviour); pass them explicitly to build the model before those globals exist.\n",
    "    \"\"\"\n",
    "    if n_features is None:\n",
    "        n_features = np.shape(X_train)[1]\n",
    "    if n_groups is None:\n",
    "        n_groups = np.shape(Z_train)[1]\n",
    "    input_X = Input(shape=(n_features,), dtype='float32')\n",
    "    input_Z = Input(shape=(n_groups,), dtype='float32')\n",
    "    # Build the hidden stack before the output layer so layer creation order is unchanged.\n",
    "    hidden = _hidden_stack(input_X, nodes, activation)\n",
    "    expxb = Dense(1, activation='exponential')(hidden)\n",
    "    expzv = Dense(1, activation='exponential', use_bias=False)(input_Z)\n",
    "    return Model(inputs=[input_X, input_Z], outputs=[expxb*expzv])\n",
    "\n",
    "\n",
    "def make_mean_model_PG(nodes, activation, n_features=None, n_groups=None):\n",
    "    \"\"\"Build the PG-NN mean model returning the linear predictors [f(X), Z v].\n",
    "\n",
    "    Same layers as make_mean_model_PF, but with linear output activations and the two\n",
    "    terms returned separately (presumably combined by the loss in PGNN_source). The\n",
    "    n_features / n_groups defaults follow make_mean_model_PF.\n",
    "    \"\"\"\n",
    "    if n_features is None:\n",
    "        n_features = np.shape(X_train)[1]\n",
    "    if n_groups is None:\n",
    "        n_groups = np.shape(Z_train)[1]\n",
    "    input_X = Input(shape=(n_features,), dtype='float32')\n",
    "    input_Z = Input(shape=(n_groups,), dtype='float32')\n",
    "    hidden = _hidden_stack(input_X, nodes, activation)\n",
    "    xb = Dense(1, activation='linear')(hidden)\n",
    "    zv = Dense(1, activation='linear', use_bias=False)(input_Z)\n",
    "    return Model(inputs=[input_X, input_Z], outputs=[xb, zv])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4132d483",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulation setting 'n_sub-n_num-lam-rand_dist'; parsed in the training cell and also\n",
    "# used to name the input files simul-data-{data_type}-{k}.csv for k in range(n_simul).\n",
    "data_type = '100-100-1-gamma'\n",
    "n_simul = 2\n",
    "\n",
    "# Optimisation settings\n",
    "phi_init = 0.8  # NOTE(review): not referenced elsewhere in this notebook\n",
    "lam_init = 0.8  # passed to pg.train_model\n",
    "lr = 0.005\n",
    "optimizer = Adam(learning_rate=lr)\n",
    "pretrain = 50\n",
    "moments_epochs = 50\n",
    "max_epochs = 150\n",
    "# patience exceeds the pretrain + max_epochs epochs given to M.fit, so this EarlyStopping\n",
    "# never ends PF-NN training early; pg.train_model receives patience separately.\n",
    "patience = 1000\n",
    "callbacks = [EarlyStopping(monitor='val_loss', patience=patience)]\n",
    "\n",
    "# Network architecture: hidden-layer widths of the X branch and their activation\n",
    "nodes = [10, 10, 10]\n",
    "activation = 'leaky_relu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91125a9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dir_name = os.getcwd() + '/'\n",
    "n_sub, n_num, lam, rand_dist = data_type.split('-')\n",
    "n_sub, n_num, lam = int(n_sub), int(n_num), float(lam)\n",
    "# Per-subject counts for each split. Validation takes the remainder, so the three splits\n",
    "# always partition range(n_num) and N_valid matches the rows data_valid actually selects.\n",
    "n_num_train, n_num_test = int(n_num*0.6), int(n_num*0.2)\n",
    "n_num_valid = n_num - n_num_train - n_num_test\n",
    "N_train, N_valid, N_test = (n_sub*k for k in (n_num_train, n_num_valid, n_num_test))\n",
    "\n",
    "u_true = np.zeros((n_sub, n_simul))\n",
    "v_pred_0 = np.zeros((n_sub, n_simul))\n",
    "v_pred_1 = np.zeros((n_sub, n_simul))\n",
    "v_pred_2 = np.zeros((n_sub, n_simul))\n",
    "\n",
    "\n",
    "def split_arrays(frame):\n",
    "    \"\"\"Convert one data split into model inputs.\n",
    "\n",
    "    Returns (X, y, z, Z): covariates x0..x4 and response y as float32, the subject\n",
    "    index from column 'sub' as int32, and its one-hot encoding (n_sub columns, float32).\n",
    "    \"\"\"\n",
    "    X = np.array(frame[['x'+str(i) for i in range(5)]], dtype=np.float32)\n",
    "    y = np.array(frame['y'], dtype=np.float32)\n",
    "    z = np.array(frame['sub'].astype('int32'))\n",
    "    Z = np.eye(n_sub)[z].astype('float32')\n",
    "    return X, y, z, Z\n",
    "\n",
    "\n",
    "def fresh_optimizer():\n",
    "    \"\"\"Return a new optimizer with the class and settings of `optimizer` but no state.\n",
    "\n",
    "    Reusing one optimizer instance across models carries its iteration counter and\n",
    "    slot variables (e.g. Adam moments) into the next fit, so each model's result would\n",
    "    depend on which models were trained before it.\n",
    "    \"\"\"\n",
    "    return optimizer.__class__.from_config(optimizer.get_config())\n",
    "\n",
    "\n",
    "def subject_coefficients(weights):\n",
    "    \"\"\"Return column 0 of a weight matrix after checking it has one row per subject.\n",
    "\n",
    "    Guards the positional weight lookups below: a wrongly picked array (e.g. a bias of\n",
    "    shape (1,)) would otherwise fail obscurely or be silently broadcast over all subjects.\n",
    "    \"\"\"\n",
    "    weights = np.asarray(weights)\n",
    "    assert weights.ndim == 2 and weights.shape[0] == n_sub, weights.shape\n",
    "    return weights[:, 0]\n",
    "\n",
    "\n",
    "for simul_num in range(n_simul):\n",
    "    K.clear_session(); pg.seed_everything()\n",
    "    file_name = dir_name + 'simul-data-' + data_type + '-' + str(simul_num)\n",
    "    data = pd.read_csv(file_name+'.csv')\n",
    "    data_train = data[data['num'].isin(range(n_num_train))]\n",
    "    data_valid = data[data['num'].isin(range(n_num_train, n_num - n_num_test))]\n",
    "    data_test = data[~data['num'].isin(range(n_num - n_num_test))]\n",
    "    # Module-level names on purpose: make_mean_model_PF/PG read X_train and Z_train.\n",
    "    X_train, y_train, z_train, Z_train = split_arrays(data_train)\n",
    "    X_valid, y_valid, z_valid, Z_valid = split_arrays(data_valid)\n",
    "    X_test, y_test, z_test, Z_test = split_arrays(data_test)\n",
    "\n",
    "    # Assumes rows are grouped by subject, n_num rows each, under the default RangeIndex.\n",
    "    u_true[:, simul_num] = np.array([data['u'][n_num*i] for i in range(n_sub)])\n",
    "\n",
    "    batch_size, batch_ratio = N_train, 1.\n",
    "    pg.seed_everything()\n",
    "    train_batch = tf.data.Dataset.from_tensor_slices((X_train, Z_train, y_train)).shuffle(N_train).batch(batch_size)\n",
    "\n",
    "    # (0) PF-NN: Keras fit with the Poisson loss.\n",
    "    K.clear_session(); pg.seed_everything()\n",
    "    M = make_mean_model_PF(nodes, activation)\n",
    "    M.compile(optimizer=fresh_optimizer(), loss=tf.keras.losses.Poisson())\n",
    "    M_history = M.fit([X_train, Z_train], y_train, epochs=(max_epochs+pretrain), batch_size=batch_size, verbose=0,\n",
    "        callbacks=callbacks, validation_data=([X_valid, Z_valid], y_valid))\n",
    "    wts_0 = M.get_weights()[-1]\n",
    "    v_pred_0[:, simul_num] = subject_coefficients(wts_0)\n",
    "\n",
    "    # (1) PG-NN via pg.train_model with pg.pg_hlik_loss, adjust=False.\n",
    "    K.clear_session(); pg.seed_everything()\n",
    "    M = make_mean_model_PG(nodes, activation)\n",
    "    res_1 = pg.train_model(M, train_batch, [X_train, Z_train, y_train], [X_valid, Z_valid, y_valid],\n",
    "         pg.pg_hlik_loss, fresh_optimizer(), lam_init, batch_ratio, patience, pretrain, max_epochs, moments_epochs, adjust=False)\n",
    "    wts_1 = res_1['wts']\n",
    "    v_pred_1[:, simul_num] = subject_coefficients(wts_1[-1])\n",
    "\n",
    "    # (2) Same as (1) with adjust=True.\n",
    "    K.clear_session(); pg.seed_everything()\n",
    "    M = make_mean_model_PG(nodes, activation)\n",
    "    res_2 = pg.train_model(M, train_batch, [X_train, Z_train, y_train], [X_valid, Z_valid, y_valid],\n",
    "         pg.pg_hlik_loss, fresh_optimizer(), lam_init, batch_ratio, patience, pretrain, max_epochs, moments_epochs, adjust=True)\n",
    "    wts_2 = res_2['wts']\n",
    "    v_pred_2[:, simul_num] = subject_coefficients(wts_2[-1])\n",
    "\n",
    "u_pred_0, u_pred_1, u_pred_2 = np.exp(v_pred_0), np.exp(v_pred_1), np.exp(v_pred_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c42196",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subjects (row indices of u_true / u_pred_*) shown in the plots and used for the RMSE below:\n",
    "# the first min(100, n_sub) subjects, or a seeded random sample when random_subset is True.\n",
    "n_shown = min(100, n_sub)\n",
    "random_subset = False\n",
    "pg.seed_everything()\n",
    "if random_subset:\n",
    "    data_number = np.random.choice(range(n_sub), n_shown, replace=False)\n",
    "else:\n",
    "    data_number = np.arange(n_shown)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c85ff2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predicted vs. true u for each method, overlaying simulation 1 (blue circles) and\n",
    "# simulation 0 (red crosses) for the subjects in data_number; the line is y = x.\n",
    "marker_size = 8\n",
    "panels = [\n",
    "    (u_pred_0, '(a) PF-NN'),\n",
    "    (u_pred_1, '(b) PG-NN without adjustment'),\n",
    "    (u_pred_2, '(c) PG-NN with adjustment'),\n",
    "]\n",
    "# (simulation column, format string); simulations that were not run are skipped.\n",
    "series = [(s, fmt) for s, fmt in ((1, 'bo'), (0, 'rx')) if s < n_simul]\n",
    "\n",
    "fig, axs = plt.subplots(1, 3, figsize=(15,5), dpi=300)\n",
    "for ax, (u_pred, title) in zip(axs, panels):\n",
    "    for s, fmt in series:\n",
    "        ax.plot(u_true[data_number, s], u_pred[data_number, s], fmt, mfc='none', markersize=marker_size)\n",
    "    ax.set_title(title, fontsize=13)\n",
    "    ax.set_xlabel('u (true)', fontsize=12)\n",
    "    ax.set_ylabel('u (predicted)', fontsize=12)\n",
    "    ax.set_ylim([0, 6])\n",
    "    ax.set_xlim([0, 6])\n",
    "    ax.grid(True, ls='--')\n",
    "    ax.axline((0, 0), slope=1)\n",
    "\n",
    "# plt.show() takes no figure; a positional argument is forwarded to the backend's show().\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7e8d05f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse_of_u(u_pred):\n",
    "    \"\"\"RMSE of u_pred against u_true over the subjects in data_number, pooling\n",
    "    simulations 1 and 0 (those that were run), i.e. the points plotted above.\n",
    "    \"\"\"\n",
    "    sims = [s for s in (1, 0) if s < n_simul]\n",
    "    squared_errors = [(u_true[data_number, s] - u_pred[data_number, s])**2 for s in sims]\n",
    "    return np.sqrt(np.mean(squared_errors))\n",
    "\n",
    "\n",
    "for label, u_pred in [('PF-NN', u_pred_0), ('PG-NN w/o adjust', u_pred_1), ('PG-NN w/ adjust', u_pred_2)]:\n",
    "    print('===== RMSE of u from ' + label + ' =====')\n",
    "    print(np.round(rmse_of_u(u_pred), 3))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
