{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Python packages used in this code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from random import shuffle\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "import sklearn\n",
    "import platform\n",
    "import sys\n",
    "from sklearn.kernel_ridge import KernelRidge\n",
    "from sklearn.gaussian_process.kernels import Matern, RBF\n",
    "from sklearn.base import BaseEstimator, RegressorMixin\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from sklearn.linear_model import Ridge\n",
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "from IPython.display import clear_output\n",
    "from scipy import io\n",
    "import math\n",
    "import joblib\n",
    "\n",
    "## Keras\n",
    "import tensorflow\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.models import Sequential, Model, model_from_json, load_model\n",
    "from tensorflow.keras.layers import Dense, Input, Add, Lambda, Dropout, Subtract, Multiply, Concatenate, Dot, BatchNormalization, Activation, LeakyReLU, ReLU\n",
    "from tensorflow.keras.losses import mse\n",
    "import keras.backend as K\n",
    "from tensorflow.keras import regularizers\n",
    "from keras.regularizers import Regularizer\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL']='2'\n",
    "tensorflow.get_logger().setLevel(\"ERROR\")\n",
    "import tensorflow as tf\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment this notebook was last validated against:\n",
    "#\n",
    "#   --Platform--\n",
    "#   OS : Windows-10-10.0.19044-SP0\n",
    "#   --Version--\n",
    "#   python :  3.9.12 (main, Apr  4 2022, 05:22:27) [MSC v.1916 64 bit (AMD64)]\n",
    "#   numpy : 1.23.1\n",
    "#   pandas : 1.4.3\n",
    "#   sklearn : 1.1.1\n",
    "#\n",
    "# The prints below report the current kernel's environment for comparison.\n",
    "print('--Platform--')\n",
    "print('OS :', platform.platform())\n",
    "print('--Version--')\n",
    "for label, version in [('python : ', sys.version),\n",
    "                       ('numpy :', np.__version__),\n",
    "                       ('pandas :', pd.__version__),\n",
    "                       ('sklearn :', sklearn.__version__)]:\n",
    "    print(label, version)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model class proposed in the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class cls_L2SP_NN(BaseEstimator, RegressorMixin):\n",
    "    '''Fine-tune a pretrained source network on one target output (L2-SP style).\n",
    "\n",
    "    The source architecture and weights are read in `fit` from the notebook\n",
    "    globals `source_model_path_json` / `source_model_path_hdf5`, which must be\n",
    "    set before fitting.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    learning_rate : float\n",
    "        Adam learning rate.\n",
    "    epochs : int\n",
    "        Number of training epochs (batch size 1).\n",
    "    n_frozen : int\n",
    "        Stored as a hyperparameter; not used by `fit` as written.\n",
    "    lambda1, lambda2 : float\n",
    "        Passed to `CustomRegularizer` as `l1` / `l2`.\n",
    "    '''\n",
    "    def __init__(self, learning_rate=0.01, epochs=10, n_frozen=0, lambda1=0.1, lambda2=0.1):\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epochs = epochs\n",
    "        self.n_frozen = n_frozen\n",
    "        self.lambda1 = lambda1\n",
    "        self.lambda2 = lambda2\n",
    "\n",
    "    def fit(self, X, y=None):\n",
    "        '''Load the source network, attach a 1-unit head and train it on (X, y).'''\n",
    "        # NOTE(review): `CustomRegularizer` is not defined anywhere in this notebook;\n",
    "        # it must be defined/imported before this method can run.\n",
    "        with open(source_model_path_json, 'r') as f_json:\n",
    "            self.source_model = model_from_json(f_json.read())\n",
    "        self.source_model.load_weights(source_model_path_hdf5)\n",
    "\n",
    "        # New single-unit head on top of layer 10, initialized with the mean of the\n",
    "        # source head's (layer 11) kernel columns and of its biases.\n",
    "        src_weights = self.source_model.layers[11].get_weights()\n",
    "        head_init = [src_weights[0].mean(1).reshape(-1, 1), src_weights[1].mean().reshape(1,)]\n",
    "        h = Dense(units=1, name='Dense_add')(self.source_model.layers[10].output)\n",
    "        self.model = Model(inputs=self.source_model.inputs, outputs=h)\n",
    "        self.model.layers[11].set_weights(head_init)\n",
    "\n",
    "        # NOTE(review): ref_model is built on the same layers as self.model, so it shares\n",
    "        # (does not copy) their weights and will follow them during training -- confirm intent.\n",
    "        self.ref_model = Model(inputs=self.source_model.inputs, outputs=h)\n",
    "        self.ref_model.layers[11].set_weights(head_init)\n",
    "\n",
    "        # NOTE(review): every layer receives w0 = ref_model.get_weights()[0] (the first\n",
    "        # weight array only), and assigning kernel_regularizer to an already-built layer\n",
    "        # may not register a loss term in Keras -- verify the penalty is actually applied.\n",
    "        for i_layer in [1, 4, 7, 9, 11]:\n",
    "            self.model.layers[i_layer].kernel_regularizer = CustomRegularizer(w0=self.ref_model.get_weights()[0], l1=self.lambda1, l2=self.lambda2)\n",
    "\n",
    "        self.model.compile(loss='mse',\n",
    "                           optimizer=keras.optimizers.Adam(self.learning_rate),\n",
    "                           metrics=['mae', 'mse'])\n",
    "\n",
    "        fix_seed(373)\n",
    "        self.history = self.model.fit(\n",
    "            X,\n",
    "            y,\n",
    "            batch_size=1,\n",
    "            epochs=self.epochs,\n",
    "            validation_split=0,\n",
    "            verbose=0\n",
    "        )\n",
    "        K.clear_session()\n",
    "        return self\n",
    "\n",
    "    def predict(self, X):\n",
    "        '''Return the model output for X, shape (n_samples, 1).'''\n",
    "        return self.model.predict(X, verbose=0)\n",
    "\n",
    "    def score(self, X, y=None):\n",
    "        '''Negative sum of squared errors on (X, y); higher is better.'''\n",
    "        # predict() returns (n, 1); flatten both sides so they are not broadcast\n",
    "        # against each other into an (n, n) matrix of pairwise differences.\n",
    "        return -np.sum((np.asarray(y).reshape(-1) - self.predict(X).reshape(-1))**2)\n",
    "\n",
    "    def get_params(self, deep=True):\n",
    "        # Must list every __init__ argument: sklearn.base.clone (used by GridSearchCV)\n",
    "        # rebuilds the estimator from this dict, so omitted ones reset to defaults.\n",
    "        return {\n",
    "            'learning_rate' : self.learning_rate,\n",
    "            'epochs' : self.epochs,\n",
    "            'n_frozen' : self.n_frozen,\n",
    "            'lambda1' : self.lambda1,\n",
    "            'lambda2' : self.lambda2\n",
    "        }\n",
    "\n",
    "    def set_params(self, **parameters):\n",
    "        for parameter, value in parameters.items():\n",
    "            setattr(self, parameter, value)\n",
    "        return self"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fix_seed function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fix_seed(seed):\n",
    "    '''Seed every random number generator this notebook relies on.\n",
    "\n",
    "    Seeds Python's `random`, NumPy's global RNG and TensorFlow's global seed,\n",
    "    and exports PYTHONHASHSEED. PYTHONHASHSEED only affects Python processes\n",
    "    started afterwards; it does not change hashing in the running kernel.\n",
    "    '''\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    tensorflow.random.set_seed(seed)\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def plot_scatter(y_obs1_list,\n",
    "                 y_prd1_list,\n",
    "                 y_obs2_list,\n",
    "                 y_prd2_list,\n",
    "                 title_list,\n",
    "                 plt_row,\n",
    "                 plt_col,\n",
    "                 position_list,\n",
    "                 col_list1,\n",
    "                 alpha_list1,\n",
    "                 col_list2,\n",
    "                 alpha_list2,\n",
    "                 fig_size,\n",
    "                 save_name,\n",
    "                 title,\n",
    "                 show_flg=True):\n",
    "    '''Draw observation-vs-prediction scatter panels and save the figure.\n",
    "\n",
    "    Panel i overlays series 1 (y_obs1_list[i] vs y_prd1_list[i], drawn on top)\n",
    "    and series 2 (y_obs2_list[i] vs y_prd2_list[i]); the Corr / MSE / MAE\n",
    "    annotations are computed on series 2 only.\n",
    "\n",
    "    Args:\n",
    "        y_obs1_list, y_prd1_list, y_obs2_list, y_prd2_list: per-panel 1-D arrays.\n",
    "        title_list: per-panel titles.\n",
    "        plt_row, plt_col: subplot grid size.\n",
    "        position_list: per-panel subplot index (1-based, as in add_subplot).\n",
    "        col_list1, alpha_list1, col_list2, alpha_list2: per-panel colors / alphas.\n",
    "        fig_size: figure size passed to plt.figure.\n",
    "        save_name: output path for fig.savefig.\n",
    "        title: figure suptitle.\n",
    "        show_flg: if False, the figure is closed after saving.\n",
    "    '''\n",
    "    fig = plt.figure(figsize=fig_size)\n",
    "\n",
    "    for i_plt, position in enumerate(position_list):\n",
    "        y_obs2, y_prd2 = y_obs2_list[i_plt], y_prd2_list[i_plt]\n",
    "        ax = fig.add_subplot(plt_row, plt_col, position,\n",
    "                             title=title_list[i_plt],\n",
    "                             xlabel='Observation',\n",
    "                             ylabel='Prediction')\n",
    "        ax.scatter(y_obs1_list[i_plt], y_prd1_list[i_plt], color=col_list1[i_plt], alpha=alpha_list1[i_plt], zorder=10)\n",
    "        ax.scatter(y_obs2, y_prd2, color=col_list2[i_plt], alpha=alpha_list2[i_plt])\n",
    "\n",
    "        # Common square range covering both axes' autoscaled limits.\n",
    "        xy_min = min(ax.get_xlim()[0], ax.get_ylim()[0])\n",
    "        xy_max = max(ax.get_xlim()[1], ax.get_ylim()[1])\n",
    "        ax.axis('equal')\n",
    "        ax.axis('square')\n",
    "        ax.set_xlim([xy_min, xy_max])\n",
    "        ax.set_ylim([xy_min, xy_max])\n",
    "        ax.grid(color='gray', linestyle='dotted', linewidth=1, alpha=0.5)\n",
    "\n",
    "        metrics = [\n",
    "            (0.93, 'Corr', np.corrcoef(y_prd2, y_obs2)[0,1]),\n",
    "            (0.87, 'MSE', mean_squared_error(y_obs2, y_prd2)),\n",
    "            (0.81, 'MAE', mean_absolute_error(y_obs2, y_prd2)),\n",
    "        ]\n",
    "        for y_txt, label, value in metrics:\n",
    "            ax.text(0.03, y_txt, label+' : '+str(round(value, 4)), size=15, transform=ax.transAxes)\n",
    "\n",
    "        # y = x reference line spanning the plotted range (a fixed [-300, 300]\n",
    "        # segment would stop short for data outside that interval).\n",
    "        ax.plot([xy_min, xy_max], [xy_min, xy_max], color='gray', linewidth=0.5)\n",
    "\n",
    "    fig.tight_layout(rect=[0,0,1,0.90])\n",
    "    plt.suptitle(title, fontsize=20)\n",
    "\n",
    "    fig.savefig(save_name)\n",
    "    if not show_flg:\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Helper to avoid division by (near-)zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avoid_zero(x, tsh=1, _add=0.1):\n",
    "    if np.abs(x) < tsh:\n",
    "        if x >= 0:\n",
    "            return tsh\n",
    "        else:\n",
    "            return -tsh\n",
    "    else:\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Network definitions and training helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PACModel(keras.Model):\n",
    "    '''Single-output network: 21 inputs -> 256 -> 64 -> 32 -> 16 (ReLU) -> 1.\n",
    "\n",
    "    The forward pass is exposed as `forward` (not `call`) and is invoked\n",
    "    directly. The attribute assignment order below fixes `model.layers`\n",
    "    (hidden1..hidden4, out), which callers index by position.\n",
    "    '''\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden1 = keras.layers.Dense(256, input_shape=(21,), name='Dense1-2')\n",
    "        self.hidden2 = keras.layers.Dense(64, name='Dense2-2')\n",
    "        self.hidden3 = keras.layers.Dense(32, name='Dense3-2')\n",
    "        self.hidden4 = keras.layers.Dense(16, name='Dense4-2')\n",
    "        self.out = keras.layers.Dense(1, name='Dense5-2')\n",
    "\n",
    "    def forward(self, x):\n",
    "        for hidden in (self.hidden1, self.hidden2, self.hidden3, self.hidden4):\n",
    "            x = keras.activations.relu(hidden(x))\n",
    "        return self.out(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NNModel(keras.Model):\n",
    "    '''Six-output network: 21 inputs -> 256 -> 64 -> 32 -> 16 (ReLU) -> 6.\n",
    "\n",
    "    Same hidden architecture as PACModel, with a 6-unit output head. The\n",
    "    forward pass is `forward` (not `call`); attribute order fixes the\n",
    "    positions in `model.layers` (hidden1..hidden4, out) that callers index.\n",
    "    '''\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden1 = keras.layers.Dense(256, input_shape=(21,), name='Dense1')\n",
    "        self.hidden2 = keras.layers.Dense(64, name='Dense2')\n",
    "        self.hidden3 = keras.layers.Dense(32, name='Dense3')\n",
    "        self.hidden4 = keras.layers.Dense(16, name='Dense4')\n",
    "        self.out = keras.layers.Dense(6, name='Dense5')\n",
    "\n",
    "    def forward(self, x):\n",
    "        for hidden in (self.hidden1, self.hidden2, self.hidden3, self.hidden4):\n",
    "            x = keras.activations.relu(hidden(x))\n",
    "        return self.out(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fun_mask(x, wk):\n",
    "    if np.abs(x)>wk:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0\n",
    "\n",
    "def loss_function(pred_y, y):\n",
    "    # return K.mean(K.mean(keras.losses.mean_squared_error(y, pred_y)))\n",
    "    return K.mean(keras.losses.mean_squared_error(y, pred_y))\n",
    "    # return K.mean(tf.math.square(tf.math.subtract(y, pred_y)))\n",
    "\n",
    "def np_to_tensor(list_of_numpy_objs):\n",
    "    return (tf.convert_to_tensor(obj) for obj in list_of_numpy_objs)\n",
    "\n",
    "def compute_loss(model, x, y, loss_fn=loss_function):\n",
    "    logits = model.forward(x)\n",
    "    mse = loss_fn(y, logits)\n",
    "    return mse, logits\n",
    "\n",
    "\n",
    "def compute_gradients(model, x, y, loss_fn=loss_function):\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss, _ = compute_loss(model, x, y, loss_fn)\n",
    "    return tape.gradient(loss, model.trainable_variables), loss\n",
    "\n",
    "def apply_gradients(optimizer, gradients, variables):\n",
    "    optimizer.apply_gradients(zip(gradients, variables))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def copy_PACmodel(model, x):\n",
    "    '''Return a new PACModel holding the same weights as `model`.\n",
    "\n",
    "    Args:\n",
    "        model: PACModel whose weights are copied.\n",
    "        x: Example input. One forward pass on it creates the new model's\n",
    "            variables (subclassed Keras layers build lazily), so that\n",
    "            set_weights has targets and gradients can later be computed.\n",
    "    Returns:\n",
    "        The copied PACModel.\n",
    "    '''\n",
    "    clone = PACModel()\n",
    "    clone.forward(tf.convert_to_tensor(x))  # build variables before assigning\n",
    "    clone.set_weights(model.get_weights())\n",
    "    return clone"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path = '../10_Data/sarcos_inv.mat'\n",
    "test_path = '../10_Data/sarcos_inv_test.mat'\n",
    "# Column layout: Position1-7, Velocity1-7, Acceleration1-7 (the 21 inputs),\n",
    "# followed by Torque1-7 (the outputs).\n",
    "axis_names = ([f'Position{i}' for i in range(1, 8)]\n",
    "              + [f'Velocity{i}' for i in range(1, 8)]\n",
    "              + [f'Acceleration{i}' for i in range(1, 8)]\n",
    "              + [f'Torque{i}' for i in range(1, 8)])\n",
    "\n",
    "sar_train_all = pd.DataFrame(io.loadmat(train_path)['sarcos_inv'], columns=axis_names)\n",
    "sar_test = pd.DataFrame(io.loadmat(test_path)['sarcos_inv_test'], columns=axis_names)\n",
    "# Only the first 30000 training rows are used (n_all in the parameter cell must match).\n",
    "sar_train = sar_train_all.iloc[:30000, :]\n",
    "\n",
    "x_train, y_train_all = sar_train.iloc[:, :21], sar_train.iloc[:, 21:]\n",
    "x_test, y_test_all = sar_test.iloc[:, :21], sar_test.iloc[:, 21:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardize inputs and all seven torques with statistics of the training rows\n",
    "# only; the test split is transformed with the same training statistics.\n",
    "x_mean, x_std = x_train.mean(), x_train.std()\n",
    "y_mean_all, y_std_all = y_train_all.mean(), y_train_all.std()\n",
    "\n",
    "x_train_scal, x_test_scal = [(frame - x_mean) / x_std for frame in (x_train, x_test)]\n",
    "y_train_all_scal, y_test_all_scal = [(frame - y_mean_all) / y_std_all for frame in (y_train_all, y_test_all)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## User parameter setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed tag of the pretrained source models (used in their file names) and global RNG seed.\n",
    "i_seed = 373\n",
    "fix_seed(i_seed)\n",
    "\n",
    "# Targets, training-set sizes, and number of random splits per size.\n",
    "target_name_list = [f'Torque{i}' for i in range(1, 8)]\n",
    "n_sample_list = [5, 10, 15, 20, 30, 40, 50]\n",
    "max_itr = 20\n",
    "\n",
    "dim_x = 21          # number of input columns\n",
    "n_all = 30000       # training rows available for sampling\n",
    "num_SourceTasks = 16\n",
    "\n",
    "# Kernel setting (these names appear unused in this notebook)\n",
    "kernel_name = 'rbf'\n",
    "nu_ = 1.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make training sample ID list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# One random permutation of the n_all training row ids per repetition.\n",
    "# Repetition n_try is seeded with n_try, so the splits are reproducible, and a\n",
    "# sample of size n is the first n ids of that permutation.\n",
    "fix_seed(373)\n",
    "sample_list = []\n",
    "for n_try in range(max_itr):\n",
    "    fix_seed(n_try)\n",
    "    permutation = list(range(n_all))\n",
    "    shuffle(permutation)\n",
    "    sample_list.append(permutation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class cls_PAC_NN(BaseEstimator, RegressorMixin):\n",
    "    '''sklearn-compatible wrapper for the PAC-Net calibration step on one target.\n",
    "\n",
    "    `fit` reads three notebook globals that must be set before it runs:\n",
    "      - PAC_weights  : initial weights for a PACModel,\n",
    "      - mask_mat_list: per-layer 0/1 masks for the four hidden kernels,\n",
    "      - wu           : per-layer kernels re-imposed on the mask == 1 entries.\n",
    "    After fit, each hidden kernel equals wu on mask == 1 entries; on mask == 0\n",
    "    entries it holds the trained value plus wu (presumably 0 there -- confirm).\n",
    "    Biases and the output layer are trained without a mask.\n",
    "    '''\n",
    "    def __init__(self, learning_rate=0.01, epochs=10):\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epochs = epochs\n",
    "    \n",
    "    def fit(self, X, y=None):\n",
    "        '''Calibrate a fresh PACModel on (X, y) (pandas objects; uses .values).\n",
    "\n",
    "        Each epoch is one full-batch SGD step; before every step the\n",
    "        mask == 1 entries of the hidden kernels are reset to wu.\n",
    "        '''\n",
    "        self.x_target, self.y_target = np_to_tensor((X.values, y.values.reshape(-1,1)))\n",
    "        \n",
    "        # Build the variables with one forward pass, then load the start weights.\n",
    "        self.model = PACModel()\n",
    "        self.model.forward(self.x_target)\n",
    "        self.model.set_weights(PAC_weights)\n",
    "        \n",
    "        self.train_loss_all = list()\n",
    "        for ep in range(self.epochs):\n",
    "            # Hidden layers 0-3: kernel * (1 - mask) + wu[j], i.e. mask == 1 entries become\n",
    "            # wu[j] and mask == 0 entries keep their value. Each kernel/bias is replaced by\n",
    "            # a new tf.Variable, which model.trainable_variables then refers to.\n",
    "            for j in range(len(self.model.layers)-1):\n",
    "                self.model.layers[j].kernel = tf.Variable(tf.add(tf.multiply(self.model.layers[j].kernel, tf.cast(tf.subtract(1, mask_mat_list[j]), tf.float32)), wu[j]), name='Dense'+str(j+1)+'-2/kernel')\n",
    "                self.model.layers[j].bias = tf.Variable(tf.multiply(1, self.model.layers[j].bias), name='Dense'+str(j+1)+'-2/bias')\n",
    "            # Output layer: re-wrapped with unchanged values, i.e. trained without a mask.\n",
    "            self.model.layers[4].kernel = tf.Variable(tf.multiply(1, self.model.layers[4].kernel), name='Dense5-2/kernel')\n",
    "            self.model.layers[4].bias = tf.Variable(tf.multiply(1, self.model.layers[4].bias), name='Dense5-2/bias')\n",
    "\n",
    "            # Extra forward pass outside the tape; its result is not used.\n",
    "            self.model.forward(self.x_target)\n",
    "            with tf.GradientTape() as train_tape:\n",
    "                train_loss, _ = compute_loss(self.model, self.x_target, self.y_target)\n",
    "            self.train_loss_all.append(train_loss.numpy())\n",
    "\n",
    "            gradients = train_tape.gradient(train_loss, self.model.trainable_variables)\n",
    "\n",
    "            # Plain SGD (no momentum by default), so a new optimizer per epoch gives the same update.\n",
    "            keras.optimizers.SGD(self.learning_rate).apply_gradients(zip(gradients, self.model.trainable_variables))\n",
    "            \n",
    "        # The last SGD step moved the mask == 1 entries; re-impose wu on them once more.\n",
    "        self.wp_kernel = [0,0,0,0]\n",
    "        for i_layer in range(len(self.wp_kernel)):\n",
    "            self.wp_kernel[i_layer] = tf.Variable(tf.multiply(self.model.layers[i_layer].kernel, tf.cast(tf.subtract(1, mask_mat_list[i_layer]), tf.float32)), name=self.model.layers[i_layer].kernel.name)\n",
    "            \n",
    "        for i_layer in range(len(self.model.layers)-1):\n",
    "            self.model.layers[i_layer].kernel = tf.Variable(tf.add(wu[i_layer], self.wp_kernel[i_layer]), name=self.model.layers[i_layer].kernel.name)\n",
    "            self.model.layers[i_layer].bias = tf.Variable(tf.multiply(1, self.model.layers[i_layer].bias), name=self.model.layers[i_layer].bias.name)\n",
    "            \n",
    "        return self\n",
    "    \n",
    "    def predict(self, X):\n",
    "        '''Return 1-D predictions for X (a pandas object; uses .values).'''\n",
    "        # X is converted twice; only the first tensor is used.\n",
    "        x_pred, _ = np_to_tensor((X.values, X.values))\n",
    "        return self.model.forward(x_pred).numpy().reshape(-1)\n",
    "\n",
    "    def score(self, X, y=None):\n",
    "        '''Negative sum of squared errors on (X, y); higher is better.'''\n",
    "        return -np.sum((y.values - self.predict(X))**2)\n",
    "    \n",
    "    def get_params(self, deep=True):\n",
    "        # Lists every __init__ argument, as sklearn.base.clone requires.\n",
    "        return {\n",
    "            'learning_rate' : self.learning_rate,\n",
    "            'epochs' : self.epochs\n",
    "        }\n",
    "    \n",
    "    def set_params(self, **parameters):\n",
    "        for parameter, value in parameters.items():\n",
    "            setattr(self, parameter, value)\n",
    "        return self  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result table: one row per (target, sample size, repetition, method)\n",
    "df_result = pd.DataFrame(columns=['data_name','n_sample','n_itr','type', 'MSE', 'Corr', 'MAE', 'R2'])\n",
    "\n",
    "# PAC-Net hyperparameters\n",
    "prune_rate = 0.1   # approx. fraction of source weights (smallest |w|) left outside the mask\n",
    "lr_A = 0.01        # step size of the masked allocation training on the source tasks\n",
    "epochs_A = 1000    # number of full-batch allocation steps\n",
    "lr_C = 0.01        # NOTE(review): unused below; the calibration learning rate comes from the grid search\n",
    "\n",
    "t0 = time.time()\n",
    "# Repeat for the different target torques\n",
    "for target_name in target_name_list:\n",
    "    print(target_name)\n",
    "    fix_seed(373)\n",
    "    \n",
    "    # Target output, and the other six (scaled) torques used as source tasks\n",
    "    y_train = y_train_all[target_name].copy()\n",
    "    y_test = y_test_all[target_name].copy()\n",
    "    ys_train_all_scal = y_train_all_scal[[s for s in target_name_list if s!=target_name]].copy()\n",
    "    x_source, y_source = np_to_tensor((x_train_scal.values, ys_train_all_scal.values.reshape(-1,6)))\n",
    "\n",
    "    source_model_path_json = '../30_Output/10_Model/100_MakeSourceModel/100_Model_'+target_name+'-'+str(i_seed)+'.json'\n",
    "    source_model_path_hdf5 = '../30_Output/10_Model/100_MakeSourceModel/100_Model_'+target_name+'-'+str(i_seed)+'.hdf5'\n",
    "    with open(source_model_path_json, 'r') as f_json:\n",
    "        source_model = model_from_json(f_json.read())\n",
    "    source_model.load_weights(source_model_path_hdf5)\n",
    "    \n",
    "    #PAC-Net\n",
    "    ## Pruning: source weights with |w| above the threshold wk get mask = 1\n",
    "    i_layer_list = [1, 3, 5, 7, 9]   # source-model layers copied into NNModel's 5 Dense layers\n",
    "    all_weight = np.array([])\n",
    "    for i in range(len(i_layer_list)):\n",
    "        all_weight = np.concatenate([all_weight, source_model.layers[i_layer_list[i]].get_weights()[0].reshape(-1)])\n",
    "    all_weight = np.abs(all_weight)\n",
    "\n",
    "    k = np.int64(np.floor(len(all_weight)*(1-prune_rate)))\n",
    "    wk = sorted(all_weight.ravel())[-k]\n",
    "\n",
    "    vfunc = np.vectorize(fun_mask)\n",
    "    mask_mat_list = []   # also read as a global by cls_PAC_NN.fit\n",
    "    for i in range(len(i_layer_list)):\n",
    "        tmp_weight = source_model.layers[i_layer_list[i]].get_weights()[0]\n",
    "        mask_mat = vfunc(tmp_weight, wk=wk)\n",
    "        mask_mat = mask_mat.astype('float32')\n",
    "        mask_mat_list.append(mask_mat)\n",
    "\n",
    "    mask_mat_list_tensor = list(np_to_tensor(mask_mat_list))\n",
    "\n",
    "    prune_model = NNModel()\n",
    "    # Build the layers before copying the source variables in. Subclassed Keras layers\n",
    "    # create their variables on the first call, and Dense.build() assigns a freshly\n",
    "    # initialized kernel/bias; without this call that happens inside the first\n",
    "    # allocation step and silently replaces the source weights copied below.\n",
    "    prune_model.forward(x_source)\n",
    "    for i_layer in range(len(i_layer_list)):\n",
    "        prune_model.layers[i_layer].kernel = source_model.layers[i_layer_list[i_layer]].kernel\n",
    "        prune_model.layers[i_layer].bias = source_model.layers[i_layer_list[i_layer]].bias\n",
    "\n",
    "    # Allocation: full-batch gradient steps on the source tasks; kernels are re-masked after every step\n",
    "    print('  --- Allocation')\n",
    "    for ep in range(epochs_A):\n",
    "        prune_model.forward(x_source)\n",
    "        with tf.GradientTape() as train_tape:\n",
    "            train_loss, _ = compute_loss(prune_model, x_source, y_source)\n",
    "\n",
    "        gradients = train_tape.gradient(train_loss, prune_model.trainable_variables)\n",
    "\n",
    "        # Assumes gradients are ordered [kernel0, bias0, kernel1, bias1, ...]\n",
    "        i_grad = 0\n",
    "        for j in range(len(prune_model.layers)):\n",
    "            prune_model.layers[j].kernel = tf.Variable(tf.multiply(tf.subtract(prune_model.layers[j].kernel, tf.multiply(lr_A, gradients[i_grad])), mask_mat_list_tensor[j]), name='Dense'+str(j)+'/kernel')\n",
    "            prune_model.layers[j].bias = tf.Variable(tf.subtract(prune_model.layers[j].bias, tf.multiply(lr_A, gradients[i_grad+1])), name='Dense'+str(j)+'/bias')\n",
    "            i_grad += 2\n",
    "    \n",
    "    # Repeat for the different number of samples\n",
    "    for num_train in n_sample_list:\n",
    "\n",
    "        # Repeat for the different sample splits\n",
    "        for n_itr in range(max_itr):\n",
    "            print(target_name+'   n : '+str(num_train)+',  try : '+str(n_itr))\n",
    "            t1 = time.time()\n",
    "\n",
    "            # Make training data\n",
    "            sample_id = sample_list[n_itr][:num_train]\n",
    "            x_train_tmp = x_train.iloc[sample_id,]\n",
    "            y_train_tmp = y_train.iloc[sample_id,]\n",
    "\n",
    "            # Scaling parameters (from the small target sample only)\n",
    "            ## Inputs\n",
    "            x_mean_tmp = x_train_tmp.mean()\n",
    "            x_std_tmp = x_train_tmp.std()\n",
    "            x_train_scal_tmp = (x_train_tmp - x_mean_tmp) / x_std_tmp.replace(0,1)\n",
    "            x_test_scal_tmp = (x_test - x_mean_tmp) / x_std_tmp.replace(0,1)\n",
    "            ## Outputs\n",
    "            y_mean_tmp = y_train_tmp.mean()\n",
    "            y_std_tmp = y_train_tmp.std()\n",
    "            y_train_scal_tmp = (y_train_tmp - y_mean_tmp) / y_std_tmp\n",
    "            y_test_scal_tmp = (y_test - y_mean_tmp) / y_std_tmp\n",
    "            # For NN\n",
    "            x_target, y_target = np_to_tensor((x_train_scal_tmp.values, y_train_scal_tmp.values.reshape(-1,1)))\n",
    "\n",
    "            #PAC-Net\n",
    "            # Calibration: single-output PACModel starting from the allocated weights;\n",
    "            # its head is the mean over the 6 source outputs.\n",
    "            print('  --- Calibration')\n",
    "            PAC_model = PACModel()\n",
    "            _ = PAC_model.forward(x_target)\n",
    "            for j in range(len(PAC_model.layers)-1):\n",
    "                PAC_model.layers[j].kernel = prune_model.layers[j].kernel\n",
    "                PAC_model.layers[j].bias = prune_model.layers[j].bias\n",
    "            PAC_model.layers[4].kernel = tf.Variable(tf.reshape(tf.reduce_mean(prune_model.layers[4].kernel, axis=1), shape=(16,1)), name='Dense5/kernel')\n",
    "            PAC_model.layers[4].bias = tf.Variable(tf.reshape(tf.reduce_mean(prune_model.layers[4].bias), shape=(1, )), name='Dense5/bias')\n",
    "            _ = PAC_model.forward(x_target)\n",
    "            PAC_weights = PAC_model.get_weights()   # read as a global by cls_PAC_NN.fit\n",
    "            \n",
    "            # wu: allocated kernels re-imposed on the mask == 1 entries during calibration\n",
    "            wu = [0,0,0,0,0]\n",
    "            for i_layer in range(len(wu)-1):\n",
    "                wu[i_layer] = prune_model.layers[i_layer].kernel.numpy()\n",
    "            wu[4] = tf.reshape(tf.reduce_mean(prune_model.layers[4].kernel, axis=1), shape=(16,1)).numpy()\n",
    "            \n",
    "            SearchParams_PAC_NN = {\n",
    "                'learning_rate' : [0.01],\n",
    "                'epochs' : [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 50, 100],\n",
    "            }\n",
    "\n",
    "            gsr_PAC = GridSearchCV(\n",
    "                cls_PAC_NN(),\n",
    "                SearchParams_PAC_NN,\n",
    "                scoring='neg_mean_squared_error',\n",
    "                cv = 5,\n",
    "                n_jobs = -1,\n",
    "                verbose=False\n",
    "            )\n",
    "            t_tmp = time.time()\n",
    "            fix_seed(373)\n",
    "            gsr_PAC.fit(x_train_scal_tmp, y_train_scal_tmp)\n",
    "\n",
    "            model_PAC = cls_PAC_NN(\n",
    "                learning_rate = gsr_PAC.best_params_['learning_rate'],\n",
    "                epochs        = gsr_PAC.best_params_['epochs']\n",
    "            )\n",
    "            fix_seed(373)\n",
    "            model_PAC.fit(x_train_scal_tmp, y_train_scal_tmp)\n",
    "            y_fits_PAC = model_PAC.predict(x_train_scal_tmp)*y_std_tmp + y_mean_tmp\n",
    "            y_pred_PAC = model_PAC.predict(x_test_scal_tmp)*y_std_tmp + y_mean_tmp\n",
    "            print('   PAC has been done.    '+str(time.time()-t_tmp))\n",
    "\n",
    "            # Save results\n",
    "            if not os.path.isdir('../30_Output/20_Plot/340_TransferLearning/'+target_name+'/n'+str(num_train)):\n",
    "                os.makedirs('../30_Output/20_Plot/340_TransferLearning/'+target_name+'/n'+str(num_train))\n",
    "            ## Plot\n",
    "            plot_scatter(\n",
    "                y_obs1_list = [y_train_tmp.values],\n",
    "                y_prd1_list = [y_fits_PAC.reshape(-1)],\n",
    "                y_obs2_list = [y_test.values],\n",
    "                y_prd2_list = [y_pred_PAC.reshape(-1)],\n",
    "                title_list  = ['PAC'],\n",
    "                plt_row     = 1,\n",
    "                plt_col     = 1,\n",
    "                position_list = [1],\n",
    "                col_list1   = ['steelblue'],\n",
    "                alpha_list1 = [1],\n",
    "                col_list2   = ['tomato'],\n",
    "                alpha_list2 = [1],\n",
    "                fig_size    = (5, 5),\n",
    "                save_name   =  '../30_Output/20_Plot/340_TransferLearning/'+target_name+'/n'+str(num_train)+'/'+'341_'+target_name+'_n'+str(num_train)+'_'+str(n_itr)+'.png',\n",
    "                title       = target_name,\n",
    "                show_flg    = False)\n",
    "\n",
    "            ## Dataframe\n",
    "            df_result = pd.concat([df_result,\n",
    "                                pd.DataFrame(np.array([\n",
    "                                    target_name, \n",
    "                                    num_train,\n",
    "                                    n_itr,\n",
    "                                    'PAC',\n",
    "                                    mean_squared_error(y_test, y_pred_PAC.reshape(-1)),\n",
    "                                    np.corrcoef(y_test, y_pred_PAC.reshape(-1))[0,1],\n",
    "                                    mean_absolute_error(y_test, y_pred_PAC.reshape(-1)),\n",
    "                                    r2_score(y_test, y_pred_PAC.reshape(-1))\n",
    "                                 ]).reshape(1, -1), columns=['data_name','n_sample','n_itr','type', 'MSE', 'Corr', 'MAE', 'R2'], \n",
    "                                index=[target_name+'_n'+str(num_train)+'_itr'+str(n_itr)+'_PAC'])], axis=0)\n",
    "            df_result.to_csv('../30_Output/30_csv/340_TransferLearning_Result.csv')\n",
    "\n",
    "            clear_output(True)\n",
    "            print(time.time()-t1, ' / ', time.time()-t0)\n",
    "clear_output(True)\n",
    "print(time.time()-t0)\n",
    "print('*** Succeeded ***')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
