{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Python packages used in this code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "import sklearn\n",
    "import platform\n",
    "import sys\n",
    "from sklearn.base import BaseEstimator, RegressorMixin\n",
    "from sklearn.model_selection import train_test_split, GridSearchCV\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "from IPython.display import clear_output\n",
    "from scipy import io\n",
    "import joblib\n",
    "\n",
    "## Keras\n",
    "import tensorflow\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.models import Sequential, Model, model_from_json, load_model\n",
    "from tensorflow.keras.layers import Dense, Input, Add, Lambda, Dropout, Subtract, Multiply, Concatenate, Dot, BatchNormalization, Activation, LeakyReLU, ReLU\n",
    "from tensorflow.keras.losses import mse\n",
    "import keras.backend as keras_backend\n",
    "from tensorflow.keras import regularizers\n",
    "os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL']='2'\n",
    "tensorflow.get_logger().setLevel(\"ERROR\")\n",
    "import tensorflow as tf\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Environments\n",
    "\n",
    "--Platform--\n",
    "OS : Windows-10-10.0.19044-SP0\n",
    "--Version--\n",
    "python :  3.9.12 (main, Apr  4 2022, 05:22:27) [MSC v.1916 64 bit (AMD64)]\n",
    "numpy : 1.23.1\n",
    "pandas : 1.4.3\n",
    "\"\"\"\n",
    "\n",
    "print('--Platform--')\n",
    "print('OS :', platform.platform())\n",
    "print('--Version--')\n",
    "print('python : ', sys.version)\n",
    "print('numpy :', np.__version__)\n",
    "print('pandas :', pd.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the model class proposed in the paper"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## fix_seed function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fix_seed(seed):\n",
    "    # Numpy\n",
    "    np.random.seed(seed)\n",
    "    # Tensorflow\n",
    "    tensorflow.random.set_seed(seed)\n",
    "    # for built-in random\n",
    "    random.seed(seed)\n",
    "    # for hash seed\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def plot_scatter(y_obs1_list, \n",
    "                 y_prd1_list,\n",
    "                 y_obs2_list, \n",
    "                 y_prd2_list,\n",
    "                 title_list, \n",
    "                 plt_row, \n",
    "                 plt_col, \n",
    "                 position_list, \n",
    "                 col_list1,\n",
    "                 alpha_list1,\n",
    "                 col_list2,\n",
    "                 alpha_list2,\n",
    "                 fig_size, \n",
    "                 save_name, \n",
    "                 title, \n",
    "                 show_flg=True):\n",
    "    fig = plt.figure(figsize=fig_size)\n",
    "\n",
    "    for i_plt in range(len(position_list)):\n",
    "        ax = fig.add_subplot(plt_row, plt_col, position_list[i_plt], \n",
    "                             title=title_list[i_plt], \n",
    "                             xlabel='Observation', \n",
    "                             ylabel='Prediction')\n",
    "        ax.scatter(y_obs1_list[i_plt], y_prd1_list[i_plt], color=col_list1[i_plt], alpha=alpha_list1[i_plt], zorder=10)\n",
    "        ax.scatter(y_obs2_list[i_plt], y_prd2_list[i_plt], color=col_list2[i_plt], alpha=alpha_list2[i_plt])\n",
    "        xy_min = min(ax.get_xlim()[0], ax.get_ylim()[0])\n",
    "        xy_max = max(ax.get_xlim()[1], ax.get_ylim()[1])\n",
    "        ax.axis('equal')\n",
    "        ax.axis('square')\n",
    "        ax.set_xlim([xy_min, xy_max])\n",
    "        ax.set_ylim([xy_min, xy_max])\n",
    "        ax.grid(color='gray', linestyle='dotted', linewidth=1, alpha=0.5)\n",
    "        ax.text(0.03, 0.93, 'Corr : '+str(round(np.corrcoef(y_prd2_list[i_plt], y_obs2_list[i_plt])[0,1], 4)), size=15, transform=ax.transAxes)\n",
    "        ax.text(0.03, 0.87, 'MSE : '+str(round(mean_squared_error(y_obs2_list[i_plt], y_prd2_list[i_plt]), 4)), size=15, transform=ax.transAxes)\n",
    "        ax.text(0.03, 0.81, 'MAE : '+str(round(mean_absolute_error(y_obs2_list[i_plt], y_prd2_list[i_plt]), 4)), size=15, transform=ax.transAxes)\n",
    "        _ = ax.plot([-300, 300], [-300, 300], color='gray', linewidth=0.5)\n",
    "\n",
    "    fig.tight_layout(rect=[0,0,1,0.90])\n",
    "    \n",
    "    plt.suptitle(title,fontsize=20)\n",
    "\n",
    "    fig.savefig(save_name)\n",
    "    if show_flg==False:\n",
    "        plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NNModel(keras.Model):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.hidden1 = keras.layers.Dense(256, input_shape=(21,))\n",
    "        self.hidden2 = keras.layers.Dense(64)\n",
    "        self.hidden3 = keras.layers.Dense(32)\n",
    "        self.hidden4 = keras.layers.Dense(16)\n",
    "        self.out = keras.layers.Dense(1)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = keras.activations.relu(self.hidden1(x))\n",
    "        x = Dropout(0.1)(x)\n",
    "        x = keras.activations.relu(self.hidden2(x))\n",
    "        x = Dropout(0.1)(x)\n",
    "        x = keras.activations.relu(self.hidden3(x))\n",
    "        x = keras.activations.relu(self.hidden4(x))\n",
    "        x = self.out(x)\n",
    "        return x\n",
    "\n",
    "def loss_function(pred_y, y):\n",
    "    return keras_backend.mean(keras.losses.mean_squared_error(y, pred_y))\n",
    "\n",
    "def np_to_tensor(list_of_numpy_objs):\n",
    "    return (tf.convert_to_tensor(obj) for obj in list_of_numpy_objs)\n",
    "\n",
    "def compute_loss(model, x, y, loss_fn=loss_function):\n",
    "    logits = model.forward(x)\n",
    "    mse = loss_fn(y, logits)\n",
    "    return mse, logits\n",
    "\n",
    "\n",
    "def compute_gradients(model, x, y, loss_fn=loss_function):\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss, _ = compute_loss(model, x, y, loss_fn)\n",
    "    return tape.gradient(loss, model.trainable_variables), loss\n",
    "\n",
    "\n",
    "def apply_gradients(optimizer, gradients, variables):\n",
    "    optimizer.apply_gradients(zip(gradients, variables))\n",
    "\n",
    "\n",
    "def train_batch(x, y, model, optimizer):\n",
    "    tensor_x, tensor_y = np_to_tensor((x, y))\n",
    "    gradients, loss = compute_gradients(model, tensor_x, tensor_y)\n",
    "    apply_gradients(optimizer, gradients, model.trainable_variables)\n",
    "    return loss\n",
    "\n",
    "def train_model(dataset, epochs=1, lr=0.001, log_steps=1000):\n",
    "    model = NNModel()\n",
    "    optimizer = keras.optimizers.Adam(learning_rate=lr)\n",
    "    for epoch in range(epochs):\n",
    "        losses = []\n",
    "        total_loss = 0\n",
    "        start = time.time()\n",
    "        for i, sinusoid_generator in enumerate(dataset):\n",
    "            x, y = sinusoid_generator.batch()\n",
    "            loss = train_batch(x, y, model, optimizer)\n",
    "            total_loss += loss\n",
    "            curr_loss = total_loss / (i + 1.0)\n",
    "            losses.append(curr_loss)\n",
    "\n",
    "            if i % log_steps == 0 and i > 0:\n",
    "                print('Step {}: loss = {}, Time to run {} steps = {:.2f} seconds'.format(\n",
    "                    i, curr_loss, log_steps, time.time() - start))\n",
    "                start = time.time()\n",
    "        plt.plot(losses)\n",
    "        plt.title('Loss Vs Time steps')\n",
    "        plt.show()\n",
    "    return model\n",
    "\n",
    "def copy_model(model, x):\n",
    "    '''Copy model weights to a new model.\n",
    "    \n",
    "    Args:\n",
    "        model: model to be copied.\n",
    "        x: An input example. This is used to run\n",
    "            a forward pass in order to add the weights of the graph\n",
    "            as variables.\n",
    "    Returns:\n",
    "        A copy of the model.\n",
    "    '''\n",
    "    copied_model = NNModel()\n",
    "    \n",
    "    # If we don't run this step the weights are not \"initialized\"\n",
    "    # and the gradients will not be computed.\n",
    "    copied_model.forward(tf.convert_to_tensor(x))\n",
    "    \n",
    "    copied_model.set_weights(model.get_weights())\n",
    "    return copied_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_test_fit(model, optimizer, x, y, x_test, y_test, num_steps=(0, 1, 10)):\n",
    "    '''Evaluate how the model fits to the curve training for `fits` steps.\n",
    "    \n",
    "    Args:\n",
    "        model: Model evaluated.\n",
    "        optimizer: Optimizer to be for training.\n",
    "        x: Data used for training.\n",
    "        y: Targets used for training.\n",
    "        x_test: Data used for evaluation.\n",
    "        y_test: Targets used for evaluation.\n",
    "        num_steps: Number of steps to log.\n",
    "    '''\n",
    "    fit_res = []\n",
    "    \n",
    "    tensor_x_test, tensor_y_test = np_to_tensor((x_test, y_test))\n",
    "    tensor_x, tensor_y = np_to_tensor((x, y))\n",
    "    \n",
    "    # If 0 in fits we log the loss before any training\n",
    "    if 0 in num_steps:\n",
    "        loss_train, logits_train = compute_loss(model, tensor_x, tensor_y)\n",
    "        loss, logits = compute_loss(model, tensor_x_test, tensor_y_test)\n",
    "        fit_res.append((0, logits_train, loss_train, logits, loss))\n",
    "        \n",
    "    for step in range(1, np.max(num_steps) + 1):\n",
    "        fix_seed(i_seed)\n",
    "        train_batch(x, y, model, optimizer)\n",
    "        loss_train, logits_train = compute_loss(model, tensor_x, tensor_y)\n",
    "        loss, logits = compute_loss(model, tensor_x_test, tensor_y_test)\n",
    "        if step in num_steps:\n",
    "            fit_res.append(\n",
    "                (\n",
    "                    step, \n",
    "                    logits_train,\n",
    "                    loss_train,\n",
    "                    logits,\n",
    "                    loss\n",
    "                )\n",
    "            )\n",
    "    return fit_res\n",
    "\n",
    "\n",
    "def eval_test(model, x, y, x_test, y_test, num_steps=(0, 1, 10), lr=0.01):\n",
    "    '''Evaluates how the sinewave addapts at dataset.\n",
    "    \n",
    "    The idea is to use the pretrained model as a weight initializer and\n",
    "    try to fit the model on this new dataset.\n",
    "    \n",
    "    Args:\n",
    "        model: Already trained model.\n",
    "        sinusoid_generator: A sinusoidGenerator instance.\n",
    "        num_steps: Number of training steps to be logged.\n",
    "        lr: Learning rate used for training on the test data.\n",
    "        plot: If plot is True than it plots how the curves are fitted along\n",
    "            `num_steps`.\n",
    "    \n",
    "    Returns:\n",
    "        The fit results. A list containing the loss, logits and step. For\n",
    "        every step at `num_steps`.\n",
    "    '''\n",
    "    x, y = np_to_tensor((x, y))\n",
    "    x_test, y_test = np_to_tensor((x_test, y_test))\n",
    "    \n",
    "    # copy model so we can use the same model multiple times\n",
    "    copied_model = copy_model(model, x)\n",
    "    \n",
    "    # use SGD for this part of training as described in the paper\n",
    "    optimizer = keras.optimizers.Adam()\n",
    "    \n",
    "    # run training and log fit results\n",
    "    fix_seed(i_seed)\n",
    "    fit_res = eval_test_fit(copied_model, optimizer, x, y, x_test, y_test, num_steps)\n",
    "    \n",
    "    return fit_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_maml(model, epochs, dataset, dataset_val, lr_inner=0.01, batch_size=1, log_steps=5, patience=1):\n",
    "    '''Train using the MAML setup.\n",
    "    \n",
    "    The comments in this function that start with:\n",
    "        \n",
    "        Step X:\n",
    "        \n",
    "    Refer to a step described in the Algorithm 1 of the paper.\n",
    "    \n",
    "    Args:\n",
    "        model: A model.\n",
    "        epochs: Number of epochs used for training.\n",
    "        dataset: A dataset used for training.\n",
    "        lr_inner: Inner learning rate (alpha in Algorithm 1). Default value is 0.01.\n",
    "        batch_size: Batch size. Default value is 1. The paper does not specify\n",
    "            which value they use.\n",
    "        log_steps: At every `log_steps` a log message is printed.\n",
    "    \n",
    "    Returns:\n",
    "        A strong, fully-developed and trained maml.\n",
    "    '''\n",
    "    optimizer = keras.optimizers.Adam()\n",
    "    \n",
    "    # Step 2: instead of checking for convergence, we train for a number\n",
    "    # of epochs\n",
    "    losses = []\n",
    "    val_losses = []\n",
    "    start = time.time()\n",
    "    ES_count = 0\n",
    "    for ep in range(epochs):\n",
    "        total_loss = 0\n",
    "        # Step 3 and 4\n",
    "        for i, t in enumerate(dataset):\n",
    "            x, y = np_to_tensor(t)\n",
    "            model.forward(x)  # run forward pass to initialize weights\n",
    "            with tf.GradientTape() as test_tape:\n",
    "                # test_tape.watch(model.trainable_variables)\n",
    "                # Step 5\n",
    "                with tf.GradientTape() as train_tape:\n",
    "                    train_loss, _ = compute_loss(model, x, y)\n",
    "                # Step 6\n",
    "                gradients = train_tape.gradient(train_loss, model.trainable_variables)\n",
    "                k = 0\n",
    "                model_copy = copy_model(model, x)\n",
    "                for j in range(len(model_copy.layers)):\n",
    "                    model_copy.layers[j].kernel = tf.subtract(model.layers[j].kernel,\n",
    "                                tf.multiply(lr_inner, gradients[k]))\n",
    "                    model_copy.layers[j].bias = tf.subtract(model.layers[j].bias,\n",
    "                                tf.multiply(lr_inner, gradients[k+1]))\n",
    "                    k += 2\n",
    "                # Step 8\n",
    "                test_loss, logits = compute_loss(model_copy, x, y)\n",
    "            # Step 8\n",
    "            gradients = test_tape.gradient(test_loss, model.trainable_variables)\n",
    "            optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "            \n",
    "            # Logs\n",
    "            total_loss += test_loss\n",
    "        losses.append(total_loss/len(dataset))\n",
    "            \n",
    "        # validation\n",
    "        total_val_loss = 0\n",
    "        for i, t in enumerate(dataset):\n",
    "            model_copy_val = copy_model(model, x)\n",
    "            x, y = np_to_tensor(t)\n",
    "            gradients, loss = compute_gradients(model_copy_val, x, y)\n",
    "            apply_gradients(optimizer, gradients, model_copy_val.trainable_variables)\n",
    "            \n",
    "            x_val, y_val = dataset_val[i]\n",
    "            val_loss, val_logits = compute_loss(model_copy_val, x_val, y_val)\n",
    "            total_val_loss += val_loss\n",
    "        val_losses.append(total_val_loss/len(dataset))\n",
    "        \n",
    "        if ep > 1 and min(val_losses[:-1])<=val_losses[-1]:\n",
    "            ES_count += 1\n",
    "            if ES_count >= patience:\n",
    "                break\n",
    "        else:\n",
    "            ES_count = 0\n",
    "            \n",
    "        if ep % log_steps == 0 and i > 0:\n",
    "            print('Epoch {}: loss = {}, val_loss = {}, Time to run {} epochs = {}'.format(ep, losses[-1], val_losses[-1], log_steps, time.time() - start))\n",
    "            start = time.time()\n",
    "            \n",
    "    clear_output(True)\n",
    "    return model, losses, val_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MAML_fit(BaseEstimator, RegressorMixin):\n",
    "    def __init__(self, epochs=1, learning_rate=0.01):\n",
    "        \"\"\"\n",
    "        Create the target model using MAML pre-trained model\n",
    "        \n",
    "        Parameters\n",
    "        ----------\n",
    "            epochs        : number of epochs for the training the target model\n",
    "            learning_rate : learning rate for the optimizer\n",
    "            \n",
    "        \"\"\"\n",
    "        self.epochs = epochs\n",
    "        self.learning_rate = learning_rate\n",
    "\n",
    "    def fit(self, X, y=None):\n",
    "        \"\"\"\n",
    "        Model fitting\n",
    "        \n",
    "        Required grobal variables\n",
    "        -----------------------\n",
    "            maml : MAML pre-trained model\n",
    "        \n",
    "        Returns\n",
    "        -------\n",
    "            X             : descriptors (pandas dataframe)\n",
    "            y             : output (pandas series)\n",
    "        \"\"\"\n",
    "        \n",
    "        tensor_x, tensor_y = np_to_tensor((X.values, y.values.reshape(-1,1)))\n",
    "\n",
    "        # copy model so we can use the same model multiple times\n",
    "        self.model = copy_model(maml, tensor_x)\n",
    "\n",
    "        self.optimizer = keras.optimizers.Adam()\n",
    "\n",
    "        # run training and log fit results\n",
    "        fix_seed(i_seed)\n",
    "        self.fit_res = []\n",
    "        # we log the loss before any training\n",
    "        loss_train, logits_train = compute_loss(self.model, tensor_x, tensor_y)\n",
    "        self.fit_res.append((0, logits_train, loss_train))\n",
    "        \n",
    "        if self.epochs > 0:\n",
    "            for step in range(1, np.max(self.epochs) + 1):\n",
    "                fix_seed(i_seed)\n",
    "                train_batch(tensor_x, tensor_y, self.model, self.optimizer)\n",
    "                loss_train, logits_train = compute_loss(self.model, tensor_x, tensor_y)\n",
    "                self.fit_res.append(\n",
    "                    (\n",
    "                        step, \n",
    "                        logits_train,\n",
    "                        loss_train\n",
    "                    )\n",
    "                )\n",
    "\n",
    "        return self\n",
    "    \n",
    "    def predict(self, X):\n",
    "        \"\"\"\n",
    "        Prediction function\n",
    "            \n",
    "        Returns\n",
    "        -------\n",
    "            y_pred\n",
    "        \"\"\"\n",
    "        return self.model.forward(tf.convert_to_tensor(X)).numpy().reshape(-1)\n",
    "\n",
    "    def score(self, X, y=None):\n",
    "        \"\"\"\n",
    "        Score function for cross-validation\n",
    "        \n",
    "        \"\"\"\n",
    "        tmp_X, tmp_y = np_to_tensor((X.values, y.values.reshape(-1,1)))\n",
    "        loss, _ = compute_loss(self.model, tmp_X, tmp_y)\n",
    "        return -loss.numpy()\n",
    "    \n",
    "    def get_params(self, deep=True):\n",
    "        \"\"\"\n",
    "        Create parameter dictionary for cross-validation\n",
    "        \n",
    "        Returns\n",
    "        -------\n",
    "            {'gamma', 'lambda1', 'nu', 'kernel'}\n",
    "        \"\"\"\n",
    "        return {'epochs' : self.epochs,\n",
    "                'learning_rate' : self.learning_rate}\n",
    "    \n",
    "    def set_params(self, **parameters):\n",
    "        \"\"\"\n",
    "        For cross-validation\n",
    "        \"\"\"\n",
    "        for parameter, value in parameters.items():\n",
    "            setattr(self, parameter, value)\n",
    "        return self     "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Main codes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path = '../10_Data/sarcos_inv.mat'\n",
    "test_path = '../10_Data/sarcos_inv_test.mat'\n",
    "axis_names = ['Position1','Position2','Position3','Position4','Position5','Position6','Position7',\n",
    "              'Velocity1','Velocity2','Velocity3','Velocity4','Velocity5','Velocity6','Velocity7',\n",
    "              'Acceleration1','Acceleration2','Acceleration3','Acceleration4','Acceleration5','Acceleration6','Acceleration7',\n",
    "              'Torque1','Torque2','Torque3','Torque4','Torque5','Torque6','Torque7']\n",
    "\n",
    "sar_train_all = io.loadmat(train_path)\n",
    "sar_test = io.loadmat(test_path)\n",
    "sar_train_all = pd.DataFrame(sar_train_all['sarcos_inv'], columns=axis_names)\n",
    "sar_test = pd.DataFrame(sar_test['sarcos_inv_test'], columns=axis_names)\n",
    "sar_train = sar_train_all.iloc[:30000, :]\n",
    "\n",
    "x_train = sar_train.iloc[:,0:21]\n",
    "x_test = sar_test.iloc[:,0:21]\n",
    "y_train_all = sar_train.iloc[:,21:]\n",
    "y_test_all = sar_test.iloc[:,21:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## User parameter setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_name_list = ['Torque1','Torque2','Torque3','Torque4','Torque5','Torque6','Torque7']\n",
    "n_sample_list = [5, 10, 15, 20, 30, 40, 50]\n",
    "max_itr = 20\n",
    "\n",
    "dim_x = 21\n",
    "n_all = 30000\n",
    "num_SourceTasks = 6\n",
    "\n",
    "# Kernel setting\n",
    "kernel_name = 'rbf'\n",
    "nu_ = 1.5\n",
    "i_seed = 373"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make training sample ID list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "fix_seed(373)\n",
    "sample_list = list()\n",
    "for n_try in range(max_itr):\n",
    "    fix_seed(n_try)\n",
    "    tmp_list = list(range(n_all))\n",
    "    random.shuffle(tmp_list)\n",
    "    sample_list.append(tmp_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Storing dataframe\n",
    "df_result = pd.DataFrame(columns=['data_name','n_sample','n_itr','type', 'MSE', 'Corr', 'MAE', 'R2', 'epochs'])\n",
    "\n",
    "t0 = time.time()\n",
    "# Repeat for the different target torques\n",
    "for target_name in target_name_list:\n",
    "\n",
    "    x_mean = x_train.mean()\n",
    "    x_std = x_train.std()\n",
    "    x_train_scal = (x_train - x_mean)/x_std\n",
    "    x_test_scal = (x_test - x_mean)/x_std\n",
    "\n",
    "    y_mean = y_train_all.mean()\n",
    "    y_std = y_train_all.std()\n",
    "    y_train_all_scal = (y_train_all - y_mean)/y_std\n",
    "    y_test_all_scal = (y_test_all - y_mean)/y_std\n",
    "\n",
    "    y_train = y_train_all[target_name].copy()\n",
    "    y_test = y_test_all[target_name].copy()\n",
    "\n",
    "    dataset     = [(x_train_scal.values, y_train_all_scal[t_name].values.reshape(-1,1)) for t_name in target_name_list if t_name!=target_name]\n",
    "    dataset_val = [(x_test_scal.values,  y_test_all_scal[t_name].values.reshape(-1,1))  for t_name in target_name_list if t_name!=target_name]\n",
    "\n",
    "    # Model training\n",
    "    print(target_name+':  MAML training')\n",
    "    maml = NNModel()\n",
    "    fix_seed(i_seed)\n",
    "    maml, losses, val_losses = train_maml(\n",
    "        model       = maml,\n",
    "        epochs      = 1000, \n",
    "        dataset     = dataset, \n",
    "        dataset_val = dataset_val,\n",
    "        lr_inner    = 0.01,\n",
    "        batch_size  = 1024,\n",
    "        log_steps   = 5,\n",
    "        patience    = 3\n",
    "    )\n",
    "    \n",
    "    # Repeat for the different number of samples\n",
    "    for num_train in n_sample_list:\n",
    "        if not os.path.isdir('../30_Output/20_Plot/320_TransferLearning/'+target_name+'/n'+str(num_train)):\n",
    "            os.makedirs('../30_Output/20_Plot/320_TransferLearning/'+target_name+'/n'+str(num_train))\n",
    "\n",
    "        for n_itr in range(max_itr):\n",
    "            print(target_name+'   n : '+str(num_train)+',  try : '+str(n_itr))\n",
    "            t1 = time.time()\n",
    "            \n",
    "            maml_copied = copy_model(maml, x_train_scal.values)\n",
    "            \n",
    "            # Make training data\n",
    "            sample_id = sample_list[n_itr][:num_train]\n",
    "            x_train_tmp = x_train.iloc[sample_id,].copy()\n",
    "            y_train_tmp = y_train.iloc[sample_id,].copy()\n",
    "            x_test_tmp = x_test.copy()\n",
    "            y_test_tmp = y_test.copy()\n",
    "\n",
    "            # Scaling parameters\n",
    "            ## Inputs\n",
    "            x_mean_tmp = x_train_tmp.mean()\n",
    "            x_std_tmp = x_train_tmp.std()\n",
    "            x_train_scal_tmp = (x_train_tmp - x_mean_tmp) / x_std_tmp.replace(0,1)\n",
    "            x_test_scal_tmp = (x_test_tmp - x_mean_tmp) / x_std_tmp.replace(0,1)\n",
    "            ## Outputs\n",
    "            y_mean_tmp = y_train_tmp.mean()\n",
    "            y_std_tmp = y_train_tmp.std()\n",
    "            y_train_scal_tmp = (y_train_tmp - y_mean_tmp) / y_std_tmp\n",
    "            y_test_scal_tmp = (y_test_tmp - y_mean_tmp) / y_std_tmp\n",
    "            \n",
    "\n",
    "            fix_seed(i_seed)\n",
    "            ### Grid search\n",
    "            gsr_maml = GridSearchCV(\n",
    "                MAML_fit(),\n",
    "                {'epochs' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 50, 100],\n",
    "                 'learning_rate' : [0.01]},\n",
    "                scoring = 'neg_mean_squared_error',\n",
    "                cv = 5,\n",
    "                n_jobs = 1,\n",
    "                refit=False,\n",
    "                verbose = True\n",
    "            )\n",
    "            fix_seed(i_seed)\n",
    "            gsr_maml.fit(x_train_scal_tmp, y_train_scal_tmp)\n",
    "\n",
    "            target_model = MAML_fit(\n",
    "                epochs = gsr_maml.best_params_['epochs'],                  \n",
    "                learning_rate = gsr_maml.best_params_['learning_rate']\n",
    "            )\n",
    "            fix_seed(i_seed)\n",
    "            ### Final model training\n",
    "            target_model.fit(x_train_scal_tmp, y_train_scal_tmp)\n",
    "\n",
    "            y_fits_maml = target_model.predict(x_train_scal_tmp) * y_std_tmp + y_mean_tmp\n",
    "            y_pred_maml = target_model.predict(x_test_scal_tmp) * y_std_tmp + y_mean_tmp\n",
    "\n",
    "            # Plot\n",
    "            plot_scatter(\n",
    "                y_obs1_list = [y_train_tmp],\n",
    "                y_prd1_list = [y_fits_maml],\n",
    "                y_obs2_list = [y_test],\n",
    "                y_prd2_list = [y_pred_maml],\n",
    "                title_list  = ['Final - rescaled'],\n",
    "                plt_row     = 1,\n",
    "                plt_col     = 1,\n",
    "                position_list = [1],\n",
    "                col_list1   = ['steelblue'],\n",
    "                alpha_list1 = [1],\n",
    "                col_list2   = ['tomato'],\n",
    "                alpha_list2 = [1],\n",
    "                fig_size    = (5, 5),\n",
    "                save_name   = './tmp.png',\n",
    "                title       = target_name,\n",
    "                show_flg    = False\n",
    "            )\n",
    "\n",
    "            ## Dataframe\n",
    "            df_result = pd.concat([df_result,\n",
    "                                pd.DataFrame(np.array([\n",
    "                                    target_name, \n",
    "                                    num_train,\n",
    "                                    n_itr,\n",
    "                                    'MAML',\n",
    "                                    mean_squared_error(y_test, y_pred_maml),\n",
    "                                    np.corrcoef(y_test, y_pred_maml)[0,1],\n",
    "                                    mean_absolute_error(y_test, y_pred_maml),\n",
    "                                    r2_score(y_test, y_pred_maml),\n",
    "                                    gsr_maml.best_params_['epochs']\n",
    "                                 ]).reshape(1, -1), columns=['data_name','n_sample','n_itr','type', 'MSE', 'Corr', 'MAE', 'R2', 'epochs'], \n",
    "                                index=[target_name+'_n'+str(num_train)+'_itr'+str(n_itr)+'_MAML'])], axis=0)\n",
    "            df_result.to_csv('../30_Output/30_csv/320_TransferLearning_Result.csv')\n",
    "            \n",
    "            clear_output(True)\n",
    "            print(time.time()-t1, ' / ', time.time()-t0)\n",
    "clear_output(True)\n",
    "print(time.time()-t0)\n",
    "print('*** Succeeded ***')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
