{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Python packages used in this code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from random import shuffle\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "import sklearn\n",
    "import platform\n",
    "import sys\n",
    "from sklearn.kernel_ridge import KernelRidge\n",
    "from sklearn.gaussian_process.kernels import Matern, RBF\n",
    "from sklearn.base import BaseEstimator, RegressorMixin\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from sklearn.linear_model import Ridge\n",
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "from IPython.display import clear_output\n",
    "from scipy import io\n",
    "import math\n",
    "import joblib\n",
    "\n",
    "## Keras\n",
    "import tensorflow\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.models import Sequential, Model, model_from_json, load_model\n",
    "from tensorflow.keras.layers import Dense, Input, Add, Lambda, Dropout, Subtract, Multiply, Concatenate, Dot, BatchNormalization, Activation, LeakyReLU, ReLU\n",
    "from tensorflow.keras.losses import mse\n",
    "import keras.backend as K\n",
    "from tensorflow.keras import regularizers\n",
    "from keras.regularizers import Regularizer\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL']='2'\n",
    "tensorflow.get_logger().setLevel(\"ERROR\")\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Environments\n",
    "\n",
    "--Platform--\n",
    "OS : Windows-10-10.0.19044-SP0\n",
    "--Version--\n",
    "python :  3.9.12 (main, Apr  4 2022, 05:22:27) [MSC v.1916 64 bit (AMD64)]\n",
    "numpy : 1.23.1\n",
    "pandas : 1.4.3\n",
    "sklearn : 1.1.1\n",
    "\"\"\"\n",
    "\n",
    "print('--Platform--')\n",
    "print('OS :', platform.platform())\n",
    "print('--Version--')\n",
    "print('python : ', sys.version)\n",
    "print('numpy :', np.__version__)\n",
    "print('pandas :', pd.__version__)\n",
    "print('sklearn :', sklearn.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model class proposed in the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomRegularizer(Regularizer):\n",
    "    def __init__(self, w0, l1=0.01, l2=0.01):\n",
    "        self.w0 = K.variable(value=w0)\n",
    "        self.l1 = l1\n",
    "        self.l2 = l2\n",
    "\n",
    "    def __call__(self, x):\n",
    "        return self.l1 * K.square(x - self.w0) + self.l2 * k.square(x)\n",
    "\n",
    "    def get_config(self):\n",
    "        return {\"w0\": self.w0, \"l1\": self.l, \"l2\": self.l2}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class cls_L2SP_NN(BaseEstimator, RegressorMixin):\n",
    "    def __init__(self, learning_rate=0.01, epochs=10, n_frozen=0, lambda1=0.1, lambda2=0.1):\n",
    "        self.learning_rate = learning_rate\n",
    "        self.epochs = epochs\n",
    "        self.n_frozen = n_frozen\n",
    "        self.lambda1 = lambda1\n",
    "        self.lambda2 = lambda2\n",
    "    \n",
    "    def fit(self, X, y=None):\n",
    "        self.source_model = model_from_json(open(source_model_path_json, 'r').read())\n",
    "        self.source_model.load_weights(source_model_path_hdf5)\n",
    "        \n",
    "        h = Dense(units=1, name='Dense_add')(self.source_model.layers[8].output)\n",
    "        self.model = Model(inputs=self.source_model.inputs, outputs=h)\n",
    "        self.model.layers[9].set_weights([\n",
    "                self.source_model.layers[9].get_weights()[0].mean(1).reshape(-1,1),\n",
    "                self.source_model.layers[9].get_weights()[1].mean().reshape(1,)\n",
    "        ])\n",
    "\n",
    "        self.ref_model = Model(inputs=self.source_model.inputs, outputs=h)\n",
    "        self.ref_model.layers[9].set_weights([\n",
    "                self.source_model.layers[9].get_weights()[0].mean(1).reshape(-1,1),\n",
    "                self.source_model.layers[9].get_weights()[1].mean().reshape(1,)\n",
    "        ])\n",
    "        \n",
    "        for i_layer in [1, 3, 5, 7, 9]:\n",
    "            self.model.layers[i_layer].kernel_regularizer = CustomRegularizer(w0=self.ref_model.layers[i_layer].get_weights()[0], l1=self.lambda1, l2=self.lambda2)\n",
    "\n",
    "        self.model.compile(loss='mse',\n",
    "                     # optimizer=keras.optimizers.Adam(self.learning_rate),\n",
    "                     optimizer=keras.optimizers.Adagrad(self.learning_rate),\n",
    "                     metrics=['mae', 'mse'])\n",
    "        \n",
    "        fix_seed(373)\n",
    "        self.history = self.model.fit(\n",
    "            X,\n",
    "            y,\n",
    "            batch_size=1,\n",
    "            epochs=self.epochs,\n",
    "            validation_split = 0,\n",
    "            verbose=0\n",
    "        )\n",
    "        K.clear_session()\n",
    "        return self\n",
    "    \n",
    "    def predict(self, X):\n",
    "        return self.model.predict(X, verbose=0)\n",
    "\n",
    "    def score(self, X, y=None):\n",
    "        return -np.sum((y.values - self.predict(X))**2)\n",
    "    \n",
    "    def get_params(self, deep=True):\n",
    "        return {\n",
    "            'learning_rate' : self.learning_rate,\n",
    "            'epochs' : self.epochs\n",
    "        }\n",
    "    \n",
    "    def set_params(self, **parameters):\n",
    "        for parameter, value in parameters.items():\n",
    "            setattr(self, parameter, value)\n",
    "        return self  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fix_seed function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fix_seed(seed):\n",
    "    # Numpy\n",
    "    np.random.seed(seed)\n",
    "    # Tensorflow\n",
    "    tensorflow.random.set_seed(seed)\n",
    "    # for built-in random\n",
    "    random.seed(seed)\n",
    "    # for hash seed\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def plot_scatter(y_obs1_list, \n",
    "                 y_prd1_list,\n",
    "                 y_obs2_list, \n",
    "                 y_prd2_list,\n",
    "                 title_list, \n",
    "                 plt_row, \n",
    "                 plt_col, \n",
    "                 position_list, \n",
    "                 col_list1,\n",
    "                 alpha_list1,\n",
    "                 col_list2,\n",
    "                 alpha_list2,\n",
    "                 fig_size, \n",
    "                 save_name, \n",
    "                 title, \n",
    "                 show_flg=True):\n",
    "    fig = plt.figure(figsize=fig_size)\n",
    "\n",
    "    for i_plt in range(len(position_list)):\n",
    "        ax = fig.add_subplot(plt_row, plt_col, position_list[i_plt], \n",
    "                             title=title_list[i_plt], \n",
    "                             xlabel='Observation', \n",
    "                             ylabel='Prediction')\n",
    "        ax.scatter(y_obs1_list[i_plt], y_prd1_list[i_plt], color=col_list1[i_plt], alpha=alpha_list1[i_plt], zorder=10)\n",
    "        ax.scatter(y_obs2_list[i_plt], y_prd2_list[i_plt], color=col_list2[i_plt], alpha=alpha_list2[i_plt])\n",
    "        xy_min = min(ax.get_xlim()[0], ax.get_ylim()[0])\n",
    "        xy_max = max(ax.get_xlim()[1], ax.get_ylim()[1])\n",
    "        ax.axis('equal')\n",
    "        ax.axis('square')\n",
    "        ax.set_xlim([xy_min, xy_max])\n",
    "        ax.set_ylim([xy_min, xy_max])\n",
    "        ax.grid(color='gray', linestyle='dotted', linewidth=1, alpha=0.5)\n",
    "        ax.text(0.03, 0.93, 'Corr : '+str(round(np.corrcoef(y_prd2_list[i_plt], y_obs2_list[i_plt])[0,1], 4)), size=15, transform=ax.transAxes)\n",
    "        ax.text(0.03, 0.87, 'MSE : '+str(round(mean_squared_error(y_obs2_list[i_plt], y_prd2_list[i_plt]), 4)), size=15, transform=ax.transAxes)\n",
    "        ax.text(0.03, 0.81, 'MAE : '+str(round(mean_absolute_error(y_obs2_list[i_plt], y_prd2_list[i_plt]), 4)), size=15, transform=ax.transAxes)\n",
    "        _ = ax.plot([-300, 300], [-300, 300], color='gray', linewidth=0.5)\n",
    "\n",
    "    fig.tight_layout(rect=[0,0,1,0.90])\n",
    "    \n",
    "    plt.suptitle(title,fontsize=20)\n",
    "\n",
    "    fig.savefig(save_name)\n",
    "    if show_flg==False:\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Function to avoid zero division"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avoid_zero(x, tsh=1, _add=0.1):\n",
    "    if np.abs(x) < tsh:\n",
    "        if x >= 0:\n",
    "            return tsh\n",
    "        else:\n",
    "            return -tsh\n",
    "    else:\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main codes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path = '../10_Data/sarcos_inv.mat'\n",
    "test_path = '../10_Data/sarcos_inv_test.mat'\n",
    "axis_names = ['Position1','Position2','Position3','Position4','Position5','Position6','Position7',\n",
    "              'Velocity1','Velocity2','Velocity3','Velocity4','Velocity5','Velocity6','Velocity7',\n",
    "              'Acceleration1','Acceleration2','Acceleration3','Acceleration4','Acceleration5','Acceleration6','Acceleration7',\n",
    "              'Torque1','Torque2','Torque3','Torque4','Torque5','Torque6','Torque7']\n",
    "\n",
    "sar_train_all = io.loadmat(train_path)\n",
    "sar_test = io.loadmat(test_path)\n",
    "sar_train_all = pd.DataFrame(sar_train_all['sarcos_inv'], columns=axis_names)\n",
    "sar_test = pd.DataFrame(sar_test['sarcos_inv_test'], columns=axis_names)\n",
    "sar_train = sar_train_all.iloc[:30000, :]\n",
    "\n",
    "x_train = sar_train.iloc[:,0:21]\n",
    "x_test = sar_test.iloc[:,0:21]\n",
    "y_train_all = sar_train.iloc[:,21:]\n",
    "y_test_all = sar_test.iloc[:,21:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## User parameter setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fix_seed(373)\n",
    "target_name_list = ['Torque1','Torque2','Torque3','Torque4','Torque5','Torque6','Torque7']\n",
    "n_sample_list = [5, 10, 15, 20, 30, 40, 50]\n",
    "max_itr = 20\n",
    "\n",
    "dim_x = 21\n",
    "n_all = 30000\n",
    "num_SourceTasks = 6\n",
    "\n",
    "# Kernel setting\n",
    "kernel_name = 'rbf'\n",
    "nu_ = 1.5\n",
    "i_seed = 373"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make training sample ID list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fix_seed(373)\n",
    "sample_list = list()\n",
    "for n_try in range(max_itr):\n",
    "    fix_seed(n_try)\n",
    "    tmp_list = list(range(n_all))\n",
    "    random.shuffle(tmp_list)\n",
    "    sample_list.append(tmp_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Storing dataframe\n",
    "df_result = pd.DataFrame(columns=['data_name','n_sample','n_itr','type', 'MSE', 'Corr', 'MAE', 'R2'])\n",
    "\n",
    "t0 = time.time()\n",
    "# Repeat for the different number of samples\n",
    "for num_train in n_sample_list:\n",
    "\n",
    "# Repeat for the different target torques\n",
    "    for target_name in target_name_list:\n",
    "\n",
    "        # Predicted data\n",
    "        predicted_data = joblib.load('../30_Output/40_pkl/100_MakeSourceModel/110_Prediction_for_'+target_name+'-'+str(i_seed)+'.pkl')\n",
    "        s_train = predicted_data['output_train']\n",
    "        s_test = predicted_data['output_test']\n",
    "        f_train = predicted_data['feature_train']\n",
    "        f_test = predicted_data['feature_test']\n",
    "\n",
    "        # Make target outputs and source features\n",
    "        y_train = y_train_all[target_name].copy()\n",
    "        y_test = y_test_all[target_name].copy()\n",
    "\n",
    "        source_model_path_json = '../30_Output/10_Model/100_MakeSourceModel/100_Model_'+target_name+'-'+str(i_seed)+'.json'\n",
    "        source_model_path_hdf5 = '../30_Output/10_Model/100_MakeSourceModel/100_Model_'+target_name+'-'+str(i_seed)+'.hdf5'\n",
    "        source_model = model_from_json(open(source_model_path_json, 'r').read())\n",
    "        source_model.load_weights(source_model_path_hdf5)\n",
    "\n",
    "\n",
    "        name_layer = source_model.layers[9].name\n",
    "        model_Ext = Model(inputs=source_model.input, outputs=source_model.get_layer(name_layer).output)\n",
    "\n",
    "        # Repeat for the different sample splits\n",
    "        for n_itr in range(max_itr):\n",
    "            print(target_name+'   n : '+str(num_train)+',  try : '+str(n_itr))\n",
    "            t1 = time.time()\n",
    "\n",
    "            # Make training data\n",
    "            sample_id = sample_list[n_itr][:num_train]\n",
    "            x_train_tmp = x_train.iloc[sample_id,]\n",
    "            y_train_tmp = y_train.iloc[sample_id,]\n",
    "            s_train_tmp = s_train.iloc[sample_id,]\n",
    "            f_train_tmp = f_train.iloc[sample_id,]\n",
    "\n",
    "            # Scaling parameters\n",
    "            ## Inputs\n",
    "            x_mean_tmp = x_train_tmp.mean()\n",
    "            x_std_tmp = x_train_tmp.std()\n",
    "            x_train_scal_tmp = (x_train_tmp - x_mean_tmp) / x_std_tmp.replace(0,1)\n",
    "            x_test_scal_tmp = (x_test - x_mean_tmp) / x_std_tmp.replace(0,1)\n",
    "            ## Outputs\n",
    "            y_mean_tmp = y_train_tmp.mean()\n",
    "            y_std_tmp = y_train_tmp.std()\n",
    "            y_train_scal_tmp = (y_train_tmp - y_mean_tmp) / y_std_tmp\n",
    "            y_test_scal_tmp = (y_test - y_mean_tmp) / y_std_tmp\n",
    "\n",
    "            #L2-SP\n",
    "            SearchParams_L2SP_NN = {\n",
    "                'learning_rate' : [1e-3],\n",
    "                'epochs' : [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 50, 100],\n",
    "                'n_frozen' : [0],\n",
    "                'lambda1' : [0.01],\n",
    "                'lambda2' : [0.01]\n",
    "            }\n",
    "\n",
    "            gsr_L2SP = GridSearchCV(\n",
    "                cls_L2SP_NN(),\n",
    "                SearchParams_L2SP_NN,\n",
    "                scoring='neg_mean_squared_error',\n",
    "                cv = 5,\n",
    "                n_jobs = -1,\n",
    "                verbose=False\n",
    "            )\n",
    "            t_tmp = time.time()\n",
    "            fix_seed(373)\n",
    "            gsr_L2SP.fit(x_train_scal_tmp, y_train_scal_tmp)\n",
    "\n",
    "            model_L2SP = cls_L2SP_NN(\n",
    "                learning_rate = gsr_L2SP.best_params_['learning_rate'],\n",
    "                epochs        = gsr_L2SP.best_params_['epochs'],\n",
    "                n_frozen      = gsr_L2SP.best_params_['n_frozen'],\n",
    "                lambda1       = gsr_L2SP.best_params_['lambda1'],\n",
    "                lambda2       = gsr_L2SP.best_params_['lambda2']\n",
    "            )\n",
    "            fix_seed(373)\n",
    "            model_L2SP.fit(x_train_scal_tmp, y_train_scal_tmp)\n",
    "            y_fits_L2SP = model_L2SP.predict(x_train_scal_tmp)*y_std_tmp + y_mean_tmp\n",
    "            y_pred_L2SP = model_L2SP.predict(x_test_scal_tmp)*y_std_tmp + y_mean_tmp\n",
    "            print('   L2-SP has been done.    '+str(time.time()-t_tmp))\n",
    "\n",
    "            # Save results\n",
    "            if not os.path.isdir('../30_Output/20_Plot/330_TransferLearning/'+target_name+'/n'+str(num_train)):\n",
    "                os.makedirs('../30_Output/20_Plot/330_TransferLearning/'+target_name+'/n'+str(num_train))\n",
    "            ## Plot\n",
    "            plot_scatter(\n",
    "                y_obs1_list = [y_train_tmp.values],\n",
    "                y_prd1_list = [y_fits_L2SP.reshape(-1)],\n",
    "                y_obs2_list = [y_test.values],\n",
    "                y_prd2_list = [y_pred_L2SP.reshape(-1)],\n",
    "                title_list  = ['L2-SP'],\n",
    "                plt_row     = 1,\n",
    "                plt_col     = 1,\n",
    "                position_list = [1],\n",
    "                col_list1   = ['steelblue'],\n",
    "                alpha_list1 = [1],\n",
    "                col_list2   = ['tomato'],\n",
    "                alpha_list2 = [1],\n",
    "                fig_size    = (5, 5),\n",
    "                save_name   =  '../30_Output/20_Plot/330_TransferLearning/'+target_name+'/n'+str(num_train)+'/'+'331_'+target_name+'_n'+str(num_train)+'_'+str(n_itr)+'.png',\n",
    "                title       = target_name,\n",
    "                show_flg    = False)\n",
    "\n",
    "            ## Dataframe\n",
    "            df_result = pd.concat([df_result,\n",
    "                                pd.DataFrame(np.array([\n",
    "                                    target_name, \n",
    "                                    num_train,\n",
    "                                    n_itr,\n",
    "                                    'L2-SP',\n",
    "                                    mean_squared_error(y_test, y_pred_L2SP.reshape(-1)),\n",
    "                                    np.corrcoef(y_test, y_pred_L2SP.reshape(-1))[0,1],\n",
    "                                    mean_absolute_error(y_test, y_pred_L2SP.reshape(-1)),\n",
    "                                    r2_score(y_test, y_pred_L2SP.reshape(-1))\n",
    "                                 ]).reshape(1, -1), columns=['data_name','n_sample','n_itr','type', 'MSE', 'Corr', 'MAE', 'R2'], \n",
    "                                index=[target_name+'_n'+str(num_train)+'_itr'+str(n_itr)+'_L2SP'])], axis=0)\n",
    "            df_result.to_csv('../30_Output/30_csv/330_TransferLearning_Result.csv')\n",
    "\n",
    "            clear_output(True)\n",
    "            print(time.time()-t1, ' / ', time.time()-t0)\n",
    "clear_output(True)\n",
    "print(time.time()-t0)\n",
    "print('*** Succeeded ***')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
