{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cf1e227",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only GPU 0 to TensorFlow. This has to run before tensorflow is imported\n",
    "# (next cell), because CUDA reads the variable when the runtime initialises.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30715edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time, random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "from datetime import datetime\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras import initializers\n",
    "from tensorflow.keras.layers import Dense, Input, Concatenate\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.callbacks import EarlyStopping, Callback\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras.optimizers.legacy import Adam, Adagrad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f54ee00f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import PGNN_source as pg\n",
    "import NNNN_source as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cede6081",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_test_rmspe(y_test, mu_test, offset_test = None):\n",
    "    residual = ((y_test - mu_test)**2)/mu_test\n",
    "    if offset_test is not None:\n",
    "        residual = residual * offset_test\n",
    "    residual = residual[residual>0]\n",
    "    if np.sum(residual<=0) >0: print(str(np.sum(residual<=0))+' were excluded for rmspe')\n",
    "    return np.sqrt(np.mean(residual))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e5ebc74",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_mean_model(nodes, layers):\n",
    "    \"\"\"Two-input mean model whose X and Z parts are returned as separate outputs.\n",
    "\n",
    "    X branch: hidden Dense layers of width `nodes` (max(layers, 1) of them; activation\n",
    "    from the notebook-level `activation`), then one linear unit -> xb.\n",
    "    Z branch: one bias-free linear unit -> zv. The model outputs [xb, zv]; combining\n",
    "    them is left to the caller. Input widths come from the globals X_train / Z_train.\n",
    "    \"\"\"\n",
    "    x_in = Input(shape=(np.shape(X_train)[1],), dtype='float32')\n",
    "    z_in = Input(shape=(np.shape(Z_train)[1],), dtype='float32')\n",
    "    hidden = Dense(nodes, activation=activation)(x_in)\n",
    "    for _ in range(layers - 1):\n",
    "        hidden = Dense(nodes, activation=activation)(hidden)\n",
    "    fixed_part = Dense(1, activation='linear')(hidden)\n",
    "    subject_part = Dense(1, activation='linear', use_bias=False)(z_in)\n",
    "    return Model(inputs=[x_in, z_in], outputs=[fixed_part, subject_part])\n",
    "\n",
    "def make_mean_model_Nw(nodes, layers):\n",
    "    \"\"\"Single-input MLP mean model (no Z input): X -> hidden layers -> one linear unit.\n",
    "\n",
    "    Builds max(layers, 1) hidden Dense layers of width `nodes` with the notebook-level\n",
    "    `activation`; the input width comes from the global X_train.\n",
    "    \"\"\"\n",
    "    x_in = Input(shape=(np.shape(X_train)[1],), dtype='float32')\n",
    "    hidden = Dense(nodes, activation=activation)(x_in)\n",
    "    for _ in range(layers - 1):\n",
    "        hidden = Dense(nodes, activation=activation)(hidden)\n",
    "    prediction = Dense(1, activation='linear')(hidden)\n",
    "    return Model(inputs=[x_in], outputs=[prediction])\n",
    "\n",
    "def make_mean_model_Nf(nodes, layers):\n",
    "    \"\"\"Single-output mean model xb + zv: an MLP on X plus a bias-free linear term on Z.\n",
    "\n",
    "    Every hidden layer (max(layers, 1) of them, width `nodes`) uses the notebook-level\n",
    "    `activation`, as in make_mean_model. Input widths come from the globals\n",
    "    X_train / Z_train.\n",
    "    \"\"\"\n",
    "    input_X = Input(shape=(np.shape(X_train)[1],), dtype='float32')\n",
    "    input_Z = Input(shape=(np.shape(Z_train)[1],), dtype='float32')\n",
    "    # The first layer honours `activation` like the others instead of a hard-coded 'sigmoid'.\n",
    "    m = Dense(nodes, activation=activation)(input_X)\n",
    "    for _ in range(layers - 1):\n",
    "        m = Dense(nodes, activation=activation)(m)\n",
    "    xb = Dense(1, activation='linear')(m)\n",
    "    zv = Dense(1, activation='linear', use_bias=False)(input_Z)\n",
    "    return Model(inputs=[input_X, input_Z], outputs=[xb + zv])\n",
    "\n",
    "def make_mean_model_Pw(nodes, layers):\n",
    "    \"\"\"Single-input mean model exp(xb): an MLP on X ending in one exponential unit.\n",
    "\n",
    "    Every hidden layer (max(layers, 1) of them, width `nodes`) uses the notebook-level\n",
    "    `activation`; the input width comes from the global X_train. The exponential\n",
    "    output unit keeps the predicted mean positive.\n",
    "    \"\"\"\n",
    "    input_X = Input(shape=(np.shape(X_train)[1],), dtype='float32')\n",
    "    # The first layer honours `activation` like the others instead of a hard-coded 'sigmoid'.\n",
    "    m = Dense(nodes, activation=activation)(input_X)\n",
    "    for _ in range(layers - 1):\n",
    "        m = Dense(nodes, activation=activation)(m)\n",
    "    expxb = Dense(1, activation='exponential')(m)\n",
    "    return Model(inputs=[input_X], outputs=[expxb])\n",
    "\n",
    "def make_mean_model_Pf(nodes, layers):\n",
    "    \"\"\"Single-output positive mean model exp(xb) * exp(zv).\n",
    "\n",
    "    xb comes from an MLP on X (max(layers, 1) hidden layers of width `nodes`, all using\n",
    "    the notebook-level `activation`); zv is a bias-free linear term on Z. Input widths\n",
    "    come from the globals X_train / Z_train.\n",
    "    \"\"\"\n",
    "    input_X = Input(shape=(np.shape(X_train)[1],), dtype='float32')\n",
    "    input_Z = Input(shape=(np.shape(Z_train)[1],), dtype='float32')\n",
    "    # The first layer honours `activation` like the others instead of a hard-coded 'sigmoid'.\n",
    "    m = Dense(nodes, activation=activation)(input_X)\n",
    "    for _ in range(layers - 1):\n",
    "        m = Dense(nodes, activation=activation)(m)\n",
    "    expxb = Dense(1, activation='exponential')(m)\n",
    "    expzv = Dense(1, activation='exponential', use_bias=False)(input_Z)\n",
    "    return Model(inputs=[input_X, input_Z], outputs=[expxb * expzv])\n",
    "\n",
    "def make_mean_model_HL():\n",
    "    \"\"\"Linear two-input mean model (no hidden layers) with outputs [xb, zv].\n",
    "\n",
    "    xb is a linear function of X with bias; zv is a bias-free linear function of Z.\n",
    "    Input widths come from the globals X_train / Z_train.\n",
    "    \"\"\"\n",
    "    x_in = Input(shape=(np.shape(X_train)[1],), dtype='float32')\n",
    "    z_in = Input(shape=(np.shape(Z_train)[1],), dtype='float32')\n",
    "    branch_outputs = [\n",
    "        Dense(1, activation='linear')(x_in),\n",
    "        Dense(1, activation='linear', use_bias=False)(z_in),\n",
    "    ]\n",
    "    return Model(inputs=[x_in, z_in], outputs=branch_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b56eb920",
   "metadata": {},
   "outputs": [],
   "source": [
    "def analyze_dataset(\n",
    "    model_names, X_train, Z_train, y_train, X_valid, Z_valid, y_valid, X_test, Z_test, y_test, nodes, layers, activation, \n",
    "    phi_init = 1., lam_init = 1., patience = 10, pretrain = 100, max_epochs = 2000, lr = 0.001, offset_test = None\n",
    "):\n",
    "    \"\"\"Fit each requested model on one dataset and return its test RMSPE.\n",
    "\n",
    "    Recognised ``model_names`` entries (anything else is ignored), fitted in this order:\n",
    "      'Nw'  MLP on X, MSE loss, Keras fit\n",
    "      'Nf'  MLP on X plus a linear term on Z, MSE loss, Keras fit\n",
    "      'NN'  make_mean_model trained by nn.train_model with nn.nn_hlik_loss; mean = xb + zv\n",
    "      'Pw'  MLP on X with exponential output, Poisson loss, Keras fit\n",
    "      'Pf'  exp(xb) * exp(zv), Poisson loss, Keras fit\n",
    "      'PG'  make_mean_model trained by pg.train_model with pg.pg_hlik_loss; mean = exp(xb + zv)\n",
    "      'HL'  as 'PG' but with the linear make_mean_model_HL\n",
    "\n",
    "    Each model is built after K.clear_session() and pg.seed_everything() and gets its own\n",
    "    Adam(learning_rate=lr). Keras-fitted models use EarlyStopping on val_loss; the\n",
    "    h-likelihood trainers receive patience, pretrain and max_epochs. Every model is\n",
    "    scored with compute_test_rmspe(y_test, mu_test, offset_test).\n",
    "\n",
    "    NOTE(review): ``activation`` is accepted but not used here: the make_mean_model*\n",
    "    builders read the notebook-level ``activation`` global instead. TODO thread it through.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    dict\n",
    "        Model name -> test RMSPE, in the fitting order above.\n",
    "    \"\"\"\n",
    "    res_rmspe = {}\n",
    "    callbacks = [EarlyStopping(monitor='val_loss', patience=patience)]\n",
    "    N_train = len(y_train)\n",
    "    # Full-batch training: a single batch holds the whole training set.\n",
    "    batch_size, batch_ratio = N_train, 1.\n",
    "\n",
    "    pg.seed_everything()\n",
    "    train_batch = tf.data.Dataset.from_tensor_slices((X_train, Z_train, y_train)).shuffle(N_train).batch(N_train)\n",
    "\n",
    "    def fresh_model(build, *args):\n",
    "        # Start every model from a cleared Keras session and the same seed.\n",
    "        K.clear_session()\n",
    "        pg.seed_everything()\n",
    "        return build(*args)\n",
    "\n",
    "    def new_optimizer():\n",
    "        # A separate optimizer per model: Adam keeps an iteration counter (used in its\n",
    "        # bias correction) and per-variable moment slots, so one shared instance made\n",
    "        # each model's updates depend on the models fitted before it.\n",
    "        return Adam(learning_rate=lr)\n",
    "\n",
    "    def fit_keras(build, loss, use_Z):\n",
    "        # Fit a single-output Keras model with early stopping; return its test RMSPE.\n",
    "        M = fresh_model(build, nodes, layers)\n",
    "        M.compile(optimizer=new_optimizer(), loss=loss)\n",
    "        if use_Z:\n",
    "            fit_in, valid_in, test_in = [X_train, Z_train], [X_valid, Z_valid], [X_test, Z_test]\n",
    "        else:\n",
    "            fit_in, valid_in, test_in = [X_train], [X_valid], [X_test]\n",
    "        M.fit(fit_in, y_train, epochs=max_epochs, batch_size=batch_size, verbose=0,\n",
    "            callbacks=callbacks, validation_data=(valid_in, y_valid))\n",
    "        mu_test = np.float32(M(test_in)).T\n",
    "        return compute_test_rmspe(y_test, mu_test, offset_test)\n",
    "\n",
    "    def test_linear_predictor(M):\n",
    "        # Sum of the model's two outputs [xb, zv] on the test set.\n",
    "        return np.sum(M([X_test, Z_test]), axis=0).T\n",
    "\n",
    "    # N-NN\n",
    "    if 'Nw' in model_names:\n",
    "        res_rmspe['Nw'] = fit_keras(make_mean_model_Nw, tf.keras.losses.MeanSquaredError(), use_Z=False)\n",
    "\n",
    "    # NF-NN\n",
    "    if 'Nf' in model_names:\n",
    "        res_rmspe['Nf'] = fit_keras(make_mean_model_Nf, tf.keras.losses.MeanSquaredError(), use_Z=True)\n",
    "\n",
    "    # NN-NN: mean = xb + zv\n",
    "    if 'NN' in model_names:\n",
    "        M = fresh_model(make_mean_model, nodes, layers)\n",
    "        nn.train_model(M, train_batch, [X_train, Z_train, y_train], [X_valid, Z_valid, y_valid],\n",
    "             nn.nn_hlik_loss, new_optimizer(), phi_init, lam_init, batch_ratio, patience, pretrain, max_epochs)\n",
    "        res_rmspe['NN'] = compute_test_rmspe(y_test, test_linear_predictor(M), offset_test)\n",
    "\n",
    "    # P-NN\n",
    "    if 'Pw' in model_names:\n",
    "        res_rmspe['Pw'] = fit_keras(make_mean_model_Pw, tf.keras.losses.Poisson(), use_Z=False)\n",
    "\n",
    "    # PF-NN\n",
    "    if 'Pf' in model_names:\n",
    "        res_rmspe['Pf'] = fit_keras(make_mean_model_Pf, tf.keras.losses.Poisson(), use_Z=True)\n",
    "\n",
    "    # PG-NN: mean = exp(xb + zv)\n",
    "    if 'PG' in model_names:\n",
    "        M = fresh_model(make_mean_model, nodes, layers)\n",
    "        pg.train_model(M, train_batch, [X_train, Z_train, y_train], [X_valid, Z_valid, y_valid],\n",
    "             pg.pg_hlik_loss, new_optimizer(), lam_init, batch_ratio, patience, pretrain, max_epochs)\n",
    "        res_rmspe['PG'] = compute_test_rmspe(y_test, np.exp(test_linear_predictor(M)), offset_test)\n",
    "\n",
    "    # PG-GLM: as PG-NN but with a linear X part\n",
    "    if 'HL' in model_names:\n",
    "        M = fresh_model(make_mean_model_HL)\n",
    "        pg.train_model(M, train_batch, [X_train, Z_train, y_train], [X_valid, Z_valid, y_valid],\n",
    "             pg.pg_hlik_loss, new_optimizer(), lam_init, batch_ratio, patience, pretrain, max_epochs)\n",
    "        res_rmspe['HL'] = compute_test_rmspe(y_test, np.exp(test_linear_predictor(M)), offset_test)\n",
    "\n",
    "    return res_rmspe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49632597",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data files are read from ./data/ under the current working directory.\n",
    "dir_name = os.getcwd()+'/data/'\n",
    "\n",
    "# Hidden-layer activation. The make_mean_model* builders read this global directly\n",
    "# (analyze_dataset's `activation` argument is not forwarded to them).\n",
    "activation = 'sigmoid'\n",
    "# Width and number of hidden layers of the X branch.\n",
    "nodes, layers = 10, 1\n",
    "# Adam learning rate. Only the Bolus, Owls and Fruits runs pass it; the Epilepsy and\n",
    "# CD4 runs fall back to analyze_dataset's default (0.001).\n",
    "lr = 0.01\n",
    "# Initial values handed to the h-likelihood trainers (phi_init only to nn.train_model).\n",
    "phi_init, lam_init = 0.5, 0.5\n",
    "# Early-stopping patience, the `pretrain` setting forwarded to nn/pg.train_model,\n",
    "# and the maximum number of epochs.\n",
    "patience, pretrain, max_epochs = 10, 100, 2000"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a77db5c",
   "metadata": {},
   "source": [
    "# Epilepsy data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae0a85d",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(dir_name+'epilepsy.csv')\n",
    "N, n_sub, n_num = 236, 59, 4\n",
    "\n",
    "# split_num codes: 0 = train, 1 = valid, 2 = test. Rows with time==4 form the test set;\n",
    "# one randomly drawn row per subject is the validation point.\n",
    "# NOTE(review): the row offsets assume the CSV is sorted by subject with exactly n_num\n",
    "# rows each and the time==4 row last, so the validation draw (offsets 0..n_num-2)\n",
    "# never hits the test row -- TODO confirm against the data file.\n",
    "split_num = np.zeros(N)\n",
    "pg.seed_everything()\n",
    "for i in range(n_sub):\n",
    "    split_num[i * n_num + np.random.choice((n_num-1), 1)] = 1\n",
    "split_num = split_num + 2.*(data['time']==4)\n",
    "data['split_num'] = split_num\n",
    "\n",
    "data_trvl = data[data['time']!=4]\n",
    "data_test = data[data['time']==4]\n",
    "data_train = data_trvl[data_trvl['split_num']==0]\n",
    "data_valid = data_trvl[data_trvl['split_num']==1]\n",
    "\n",
    "covariate_names = ['time', 'drug', 'base', 'age']\n",
    "subset_names = ['_train', '_valid', '_test', '_trvl']\n",
    "# Defines the globals N_train, X_train (float32 covariates), y_train (float32 response),\n",
    "# z_train (int32 subject id) and Z_train (one-hot of z_train, width n_sub), and likewise\n",
    "# for _valid, _test and _trvl.\n",
    "for subset in subset_names:\n",
    "    exec('temp_data = data'+subset)\n",
    "    exec('N'+subset+'= np.shape(temp_data)[0]')\n",
    "    exec('X'+subset+'= np.array(temp_data[covariate_names], dtype=np.float32)')\n",
    "    exec('y'+subset+'= np.array(temp_data[\"y\"], dtype=np.float32)')        \n",
    "    exec('z'+subset+'= np.array(temp_data[\"id\"].astype(\"int32\"))')\n",
    "    exec('Z'+subset+'= np.eye(n_sub)[z'+subset+'].astype(\"float32\")')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a325f58a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = ['Nw','Nf','NN','Pw','Pf','PG']\n",
    "# NOTE(review): this run trains on the *_trvl arrays, so the validation rows used for\n",
    "# early stopping are also training rows, and it does not pass `lr`, so analyze_dataset's\n",
    "# default (0.001) applies. Kept as-is; confirm both are intended.\n",
    "rmspes = analyze_dataset(\n",
    "    model_names, X_trvl, Z_trvl, y_trvl, X_valid, Z_valid, y_valid, X_test, Z_test, y_test,\n",
    "    nodes, layers, activation, phi_init, lam_init, patience, pretrain, max_epochs)\n",
    "# Build the table from the dict keys so every RMSPE stays under its own model's column,\n",
    "# whatever order analyze_dataset filled the dict in.\n",
    "pd.DataFrame([rmspes], index = ['RMSPE'])[model_names].round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b322218",
   "metadata": {},
   "source": [
    "# CD4 data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b145bf20",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(dir_name+'cd4.csv')\n",
    "\n",
    "data['y'] = data['cd4']\n",
    "# 0/1 indicators for three treatment arms (rows in none of them get all zeros).\n",
    "data['zA225z'] = 1*(data['treatment'] == 'zA225z')\n",
    "data['zA400d'] = 1*(data['treatment'] == 'zA400d')\n",
    "data['zX400d'] = 1*(data['treatment'] == 'zX400d')\n",
    "\n",
    "N = np.shape(data)[0] # = 4612\n",
    "n_sub = np.shape(np.unique(data['id']))[0] # = 1038\n",
    "# q[i] = number of rows of subject i (assumes ids are exactly 0..n_sub-1).\n",
    "q = np.zeros(n_sub, dtype=np.int32)\n",
    "for i in range(n_sub):\n",
    "    q[i] = int(np.sum(data['id']==i))\n",
    "\n",
    "# split_num codes: 0 = train, 1 = valid, 2 = test (rows with last_visit==1).\n",
    "# NOTE(review): `count` walks subjects as contiguous row blocks, so this assumes the CSV\n",
    "# is sorted by id with each subject's last visit as its final row. The draw\n",
    "# np.random.choice(q[i]-1, 1) raises for a subject with a single row -- TODO confirm.\n",
    "split_num = np.zeros(N)\n",
    "pg.seed_everything()\n",
    "count = 0\n",
    "for i in range(n_sub):\n",
    "    split_num[count + np.random.choice((q[i]-1), 1)] = 1\n",
    "    count += q[i]\n",
    "split_num = split_num + 2.*(data['last_visit']==1)\n",
    "data['split_num'] = split_num\n",
    "\n",
    "data_trvl = data[data['last_visit']!=1]\n",
    "data_test = data[data['last_visit']==1]\n",
    "data_train = data_trvl[data_trvl['split_num']==0]\n",
    "data_valid = data_trvl[data_trvl['split_num']==1]\n",
    "\n",
    "covariate_names = ['age', 'gender','week','zA225z', 'zA400d', 'zX400d']\n",
    "subset_names = ['_train', '_valid', '_test', '_trvl']\n",
    "# Defines the globals N_train, X_train (float32 covariates), y_train (float32 response),\n",
    "# z_train (int32 subject id) and Z_train (one-hot of z_train, width n_sub), and likewise\n",
    "# for _valid, _test and _trvl.\n",
    "for subset in subset_names:\n",
    "    exec('temp_data = data'+subset)\n",
    "    exec('N'+subset+'= np.shape(temp_data)[0]')\n",
    "    exec('X'+subset+'= np.array(temp_data[covariate_names], dtype=np.float32)')\n",
    "    exec('y'+subset+'= np.array(temp_data[\"y\"], dtype=np.float32)')        \n",
    "    exec('z'+subset+'= np.array(temp_data[\"id\"].astype(\"int32\"))')\n",
    "    exec('Z'+subset+'= np.eye(n_sub)[z'+subset+'].astype(\"float32\")')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "597c0a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = ['Nw','Nf','NN','Pw','Pf','PG']\n",
    "# NOTE(review): this run trains on the *_trvl arrays, so the validation rows used for\n",
    "# early stopping are also training rows, and it does not pass `lr`, so analyze_dataset's\n",
    "# default (0.001) applies. Kept as-is; confirm both are intended.\n",
    "rmspes = analyze_dataset(\n",
    "    model_names, X_trvl, Z_trvl, y_trvl, X_valid, Z_valid, y_valid, X_test, Z_test, y_test,\n",
    "    nodes, layers, activation, phi_init, lam_init, patience, pretrain, max_epochs)\n",
    "# Build the table from the dict keys so every RMSPE stays under its own model's column,\n",
    "# whatever order analyze_dataset filled the dict in.\n",
    "pd.DataFrame([rmspes], index = ['RMSPE'])[model_names].round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce8c1c62",
   "metadata": {},
   "source": [
    "# Bolus data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cbd69bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(dir_name+'bolus.csv')\n",
    "N, n_sub, n_num = 780, 65, 12\n",
    "\n",
    "# Boolean indicators for the 1mg and 2mg groups (cast to float32 when X_* is built).\n",
    "data['1mg'] = data['group']=='1mg'\n",
    "data['2mg'] = data['group']=='2mg'\n",
    "# split_num codes: 0 = train, 1 = valid, 2 = test (rows with time==12). One randomly\n",
    "# drawn row per subject is the validation point.\n",
    "# NOTE(review): assumes the CSV is sorted by subject with exactly n_num rows each and the\n",
    "# time==12 row last, so the draw (offsets 0..n_num-2) never hits the test row -- TODO confirm.\n",
    "split_num = np.zeros(N)\n",
    "pg.seed_everything()\n",
    "for i in range(n_sub):\n",
    "    split_num[i * n_num + np.random.choice((n_num-1), 1)] = 1\n",
    "split_num = split_num + 2.*(data['time']==12)\n",
    "data['split_num'] = split_num\n",
    "\n",
    "data_trvl = data[data['time']!=12]\n",
    "data_test = data[data['time']==12]\n",
    "data_train = data_trvl[data_trvl['split_num']==0]\n",
    "data_valid = data_trvl[data_trvl['split_num']==1]\n",
    "\n",
    "covariate_names = ['time', '2mg', '1mg']\n",
    "subset_names = ['_train', '_valid', '_test', '_trvl']\n",
    "# Defines the globals N_train, X_train (float32 covariates), y_train (float32 response),\n",
    "# z_train (int32 subject id) and Z_train (one-hot of z_train, width n_sub), and likewise\n",
    "# for _valid, _test and _trvl.\n",
    "for subset in subset_names:\n",
    "    exec('temp_data = data'+subset)\n",
    "    exec('N'+subset+'= np.shape(temp_data)[0]')\n",
    "    exec('X'+subset+'= np.array(temp_data[covariate_names], dtype=np.float32)')\n",
    "    exec('y'+subset+'= np.array(temp_data[\"y\"], dtype=np.float32)')        \n",
    "    exec('z'+subset+'= np.array(temp_data[\"id\"].astype(\"int32\"))')\n",
    "    exec('Z'+subset+'= np.eye(n_sub)[z'+subset+'].astype(\"float32\")')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "475ab737",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = ['Nw','Nf','NN','Pw','Pf','PG','HL']\n",
    "rmspes = analyze_dataset(\n",
    "    model_names, X_train, Z_train, y_train, X_valid, Z_valid, y_valid, X_test, Z_test, y_test,\n",
    "    nodes, layers, activation, phi_init, lam_init, patience, pretrain, max_epochs, lr)\n",
    "# Build the table from the dict keys so every RMSPE stays under its own model's column,\n",
    "# whatever order analyze_dataset filled the dict in.\n",
    "pd.DataFrame([rmspes], index = ['RMSPE'])[model_names].round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2fd50440",
   "metadata": {},
   "source": [
    "# Owls data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63b33e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(dir_name+'owls.csv')\n",
    "\n",
    "data['FT'] = 1*(data['FoodTreatment']=='Satiated')\n",
    "data['SP'] = 1*(data['SexParent']=='Male')\n",
    "# Response is SiblingNegotiation divided by BroodSize; BroodSize of the test rows is kept\n",
    "# below as offset_test, which compute_test_rmspe multiplies into each squared Pearson term.\n",
    "data['y']  = data['SiblingNegotiation']/data['BroodSize']\n",
    "\n",
    "N = np.shape(data)[0]\n",
    "n_sub = np.shape(np.unique(data['id']))[0]\n",
    "split_num = np.zeros(N)\n",
    "# q[i] = number of rows of subject i (assumes ids are exactly 0..n_sub-1).\n",
    "q = np.zeros(n_sub, dtype=np.int32)\n",
    "for i in range(n_sub): q[i] = int(np.sum(data['id']==i))    \n",
    "    \n",
    "# For each subject draw two distinct within-subject row positions: the first becomes\n",
    "# validation (split 1), the second test (split 2); all other rows are training (split 0).\n",
    "# Needs q[i] >= 2 for every subject. data['id'][j] is a label lookup, which matches row\n",
    "# position only because read_csv gives a default RangeIndex.\n",
    "pg.seed_everything()\n",
    "for i in range(n_sub):\n",
    "    count = 0; valid_num, test_num = np.random.choice(q[i], 2, replace=False)\n",
    "    for j in range(N):\n",
    "        if data['id'][j]==i:\n",
    "            if count == valid_num: split_num[j] = 1\n",
    "            elif count == test_num: split_num[j] = 2\n",
    "            count += 1\n",
    "data['split'] = split_num\n",
    "\n",
    "data_valid = data[data['split']==1]\n",
    "data_train = data[data['split']==0]\n",
    "data_test = data[data['split']==2]\n",
    "data_trvl = data[data['split']!=2]\n",
    "\n",
    "covariate_names = ['ArrivalTime', 'FT', 'SP']\n",
    "subset_names = ['_train', '_valid', '_test', '_trvl']\n",
    "# Defines the globals N_train, X_train (float32 covariates), y_train (float32 response),\n",
    "# z_train (int32 subject id) and Z_train (one-hot of z_train, width n_sub), and likewise\n",
    "# for _valid, _test and _trvl.\n",
    "for subset in subset_names:\n",
    "    exec('temp_data = data'+subset)\n",
    "    exec('N'+subset+'= np.shape(temp_data)[0]')\n",
    "    exec('X'+subset+'= np.array(temp_data[covariate_names], dtype=np.float32)')\n",
    "    exec('y'+subset+'= np.array(temp_data[\"y\"], dtype=np.float32)')        \n",
    "    exec('z'+subset+'= np.array(temp_data[\"id\"].astype(\"int32\"))')\n",
    "    exec('Z'+subset+'= np.eye(n_sub)[z'+subset+'].astype(\"float32\")')\n",
    "\n",
    "# BroodSize of each test row, passed to analyze_dataset as offset_test.\n",
    "offset_test = np.float32(data_test['BroodSize'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b0cb858",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = ['Nw','Nf','NN','Pw','Pf','PG','HL']\n",
    "rmspes = analyze_dataset(\n",
    "    model_names, X_train, Z_train, y_train, X_valid, Z_valid, y_valid, X_test, Z_test, y_test,\n",
    "    nodes, layers, activation, phi_init, lam_init, patience, pretrain, max_epochs, lr, offset_test)\n",
    "# Build the table from the dict keys so every RMSPE stays under its own model's column,\n",
    "# whatever order analyze_dataset filled the dict in.\n",
    "pd.DataFrame([rmspes], index = ['RMSPE'])[model_names].round(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "301c4f85",
   "metadata": {},
   "source": [
    "# Fruits data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d57a682",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(dir_name+'fruits.csv')\n",
    "\n",
    "# 0/1 encodings: amd (1 = 'clipped'), rack (1 = rack 2) and one indicator per listed\n",
    "# status level; amd and rack are overwritten in place with their 0/1 versions.\n",
    "data['amd'] = 1*(data['amd']=='clipped')\n",
    "data['rack'] = 1*(data['rack']==2) \n",
    "data['normal'] = 1*(data['status']=='Normal') \n",
    "data['trans'] = 1*(data['status'] =='Transplant') \n",
    "data['petri'] = 1*(data['status'] =='Petri.Plate') \n",
    "data['y']  = data['total.fruits']\n",
    "\n",
    "N = np.shape(data)[0]\n",
    "n_sub = np.shape(np.unique(data['id']))[0]\n",
    "\n",
    "split_num = np.zeros(N)\n",
    "# q[i] = number of rows of subject i (assumes ids are exactly 0..n_sub-1).\n",
    "q = np.zeros(n_sub, dtype=np.int32)\n",
    "for i in range(n_sub): q[i] = int(np.sum(data['id']==i))\n",
    "# For each subject draw two distinct within-subject row positions: the first becomes\n",
    "# validation (split 1), the second test (split 2); all other rows are training (split 0).\n",
    "# Needs q[i] >= 2 for every subject. data['id'][j] is a label lookup, which matches row\n",
    "# position only because read_csv gives a default RangeIndex.\n",
    "pg.seed_everything()\n",
    "for i in range(n_sub):\n",
    "    count = 0; valid_num, test_num = np.random.choice(q[i], 2, replace=False)\n",
    "    for j in range(N):\n",
    "        if data['id'][j]==i:\n",
    "            if count == valid_num: split_num[j] = 1\n",
    "            elif count == test_num: split_num[j] = 2\n",
    "            count += 1\n",
    "data['split'] = split_num\n",
    "\n",
    "data_valid = data[data['split']==1]\n",
    "data_train = data[data['split']==0]\n",
    "data_test = data[data['split']==2]\n",
    "data_trvl = data[data['split']!=2]\n",
    "\n",
    "covariate_names = ['nutrient', 'amd', 'rack', 'normal','trans', 'petri']\n",
    "subset_names = ['_train', '_valid', '_test', '_trvl']\n",
    "# Defines the globals N_train, X_train (float32 covariates), y_train (float32 response),\n",
    "# z_train (int32 subject id) and Z_train (one-hot of z_train, width n_sub), and likewise\n",
    "# for _valid, _test and _trvl.\n",
    "for subset in subset_names:\n",
    "    exec('temp_data = data'+subset)\n",
    "    exec('N'+subset+'= np.shape(temp_data)[0]')\n",
    "    exec('X'+subset+'= np.array(temp_data[covariate_names], dtype=np.float32)')\n",
    "    exec('y'+subset+'= np.array(temp_data[\"y\"], dtype=np.float32)')        \n",
    "    exec('z'+subset+'= np.array(temp_data[\"id\"].astype(\"int32\"))')\n",
    "    exec('Z'+subset+'= np.eye(n_sub)[z'+subset+'].astype(\"float32\")')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1791dbb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_names = ['Nw','Nf','NN','Pw','Pf','PG','HL']\n",
    "rmspes = analyze_dataset(\n",
    "    model_names, X_train, Z_train, y_train, X_valid, Z_valid, y_valid, X_test, Z_test, y_test,\n",
    "    nodes, layers, activation, phi_init, lam_init, patience, pretrain, max_epochs, lr)\n",
    "# Build the table from the dict keys so every RMSPE stays under its own model's column,\n",
    "# whatever order analyze_dataset filled the dict in.\n",
    "pd.DataFrame([rmspes], index = ['RMSPE'])[model_names].round(3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
