{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "import random\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import datetime\n",
    "import time\n",
    "import datetime\n",
    "import copy\n",
    "\n",
    "from src.utility import *\n",
    "from src.MTH.MultiTaskHypernetwork import *\n",
    "from src.MTL.MultiTaskLearning import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Multi-Task Hypernetwork experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##### Save results into DataFrame #####\n",
    "def save_df(results_df):\n",
    "    \"\"\"Write `results_df` to one CSV per dataset under ./results/.\n",
    "\n",
    "    Reads the notebook globals `data_params`, `param_loader`, `params` and `t0`.\n",
    "    The filename encodes the algorithm name, the dataset name, the `t0` timestamp,\n",
    "    the smallest and largest seed found in `results_df`, and the tuned parameter names.\n",
    "    After writing, the CSV whose upper seed is one lower is deleted if it exists.\n",
    "\n",
    "    Args:\n",
    "        results_df: DataFrame with at least the columns \"seed\" and \"data_name\".\n",
    "    \"\"\"\n",
    "    results_dir = r\"./results/\"\n",
    "    # Create the output folder up front so the first checkpoint cannot fail on a fresh clone.\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "    filename_template = r\"{}_{}___{}___seeds{:d}-{:d}.{}.csv\"\n",
    "    timestamp = t0.strftime(\"%y.%m.%d_%H.%M.%S\")\n",
    "\n",
    "    data_names = data_params[\"data_name\"] if isinstance(data_params[\"data_name\"], list) else [data_params[\"data_name\"]]\n",
    "\n",
    "    # These do not depend on the dataset, so they are computed once.\n",
    "    tune_cols = [col for col in param_loader.iter_param_keys if col not in [\"data_name\", \"lr\", \"seed\"]]\n",
    "    seeds = [int(results_df[\"seed\"].min()), int(results_df[\"seed\"].max())]\n",
    "\n",
    "    for data_name in data_names:\n",
    "        # Create new csv\n",
    "        filename = filename_template.format(\n",
    "            params[\"algo_name\"], data_name, timestamp, seeds[0], seeds[1], \".\".join(tune_cols)\n",
    "        )\n",
    "        results_df[results_df[\"data_name\"] == data_name].to_csv(results_dir + filename, index=False)\n",
    "        print(filename)\n",
    "\n",
    "        # Delete old csv\n",
    "        if seeds[1] > seeds[0]:\n",
    "            old_filename = filename_template.format(\n",
    "                params[\"algo_name\"], data_name, timestamp, seeds[0], seeds[1] - 1, \".\".join(tune_cols)\n",
    "            )\n",
    "            old_path = results_dir + old_filename\n",
    "            # The previous checkpoint may be absent; a missing file must not abort the experiment.\n",
    "            if os.path.exists(old_path):\n",
    "                os.remove(old_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "##### Data parameters  #####\n",
    "data_params = {}\n",
    "data_params[\"data_name\"] = [\"Cubic\"]\n",
    "data_params[\"num_tasks\"] = 20\n",
    "data_params[\"num_train_per_task\"] = 10\n",
    "\n",
    "# Load the first dataset once, only to read off the input/output dimensions.\n",
    "data_tmp = get_data(data_params[\"data_name\"][0])\n",
    "input_dim = data_tmp[\"X_train\"].shape[1]\n",
    "output_dim = data_tmp[\"y_train\"].shape[1]\n",
    "del data_tmp\n",
    "\n",
    "##### Model parameters #####\n",
    "params = {}\n",
    "\n",
    "params[\"algo_name\"] = \"HyperNet\"\n",
    "\n",
    "params[\"lr\"] = [10**p for p in [-3, -3.5, -4, -4.5]]\n",
    "params[\"seed\"] = [*range(0, 50)]\n",
    "\n",
    "params[\"use_metadata\"] = [True, False]\n",
    "\n",
    "# Hypernetwork architecture parameters\n",
    "params[\"embedding_dim\"] = [10]\n",
    "params[\"hidden_dim\"] = [32]\n",
    "params[\"hyper_extractor_layers\"] = [0, 1, 2]\n",
    "\n",
    "# Target network architecture\n",
    "params[\"target_arch\"] = [[\n",
    "    {\"type\": \"linear\", \"params\": [input_dim, 32]},\n",
    "    {\"type\": \"relu\"},\n",
    "    {\"type\": \"linear\", \"params\": [32, 32]},\n",
    "    {\"type\": \"relu\"},\n",
    "    {\"type\": \"linear\", \"params\": [32, 32]},\n",
    "    {\"type\": \"relu\"},\n",
    "    {\"type\": \"linear\", \"params\": [32, output_dim]}\n",
    "]]\n",
    "\n",
    "params[\"verbose\"] = True\n",
    "\n",
    "\n",
    "##### Run experiments #####\n",
    "t0 = datetime.datetime.now()  # run timestamp; save_df reads this global for the CSV filename\n",
    "records = []                  # one entry (metrics + data parameters) per finished run\n",
    "results_df = pd.DataFrame()\n",
    "\n",
    "param_loader = param_iterator(params, data_params)\n",
    "\n",
    "# Number of runs per seed; a checkpoint is written every time this many runs complete.\n",
    "# NOTE(review): this assumes param_iterator varies \"seed\" slowest - confirm in src.utility.\n",
    "runs_per_seed = param_loader.num_combinations // len(params[\"seed\"])\n",
    "\n",
    "for i in range(param_loader.num_combinations):\n",
    "\n",
    "    # Get next param combination\n",
    "    p, d_p = param_loader.next()\n",
    "\n",
    "    # Train model\n",
    "    model = train_HyperNet(p, d_p)\n",
    "\n",
    "    # Record results\n",
    "    if hasattr(model, \"metrics\"):\n",
    "        results = model.metrics.copy()\n",
    "        for key in d_p:\n",
    "            results[key] = d_p[key]\n",
    "        records.append(results)\n",
    "\n",
    "    # Save results dataframe\n",
    "    if (i+1) % runs_per_seed == 0:\n",
    "        # DataFrame.append was removed in pandas 2.0, so rows are collected in a list\n",
    "        # and the frame is rebuilt only when a checkpoint is written.\n",
    "        results_df = pd.DataFrame(records)\n",
    "        save_df(results_df)\n",
    "\n",
    "# Leave the complete results in results_df for interactive inspection.\n",
    "results_df = pd.DataFrame(records)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Display results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Read results DataFrame\n",
    "filename = \"HyperNet_Cubic___23.01.26_12.34.44___seeds0-49.use_metadata.hyper_extractor_layers.csv\"\n",
    "results_df_cp = pd.read_csv(\"./results/\" + filename)\n",
    "\n",
    "# Variables of interest\n",
    "# The last \"___\"-separated part of the filename is \"seeds<lo>-<hi>.<tuned columns...>.csv\".\n",
    "name_parts = filename.split(\"___\")[-1].split(\".\")\n",
    "cols_to_agg = [col for col in results_df_cp.columns if col.startswith(\"acc_\")]\n",
    "tune_cols = name_parts[1:-1]\n",
    "groupby_cols = tune_cols + [\"lr\"]\n",
    "seeds = name_parts[0][5:].split(\"-\")  # drop the \"seeds\" prefix, then split \"<lo>-<hi>\"\n",
    "num_seeds = int(seeds[1]) - int(seeds[0]) + 1\n",
    "sf = 3  # number of decimals shown in the table\n",
    "\n",
    "# Aggregate results\n",
    "# Both statistics are computed before any column is added to the frame.\n",
    "grouped = results_df_cp.groupby(groupby_cols)[cols_to_agg]\n",
    "group_mean = grouped.transform(\"mean\")\n",
    "group_std = grouped.transform(\"std\")\n",
    "results_df_cp[[c + \"_mean\" for c in cols_to_agg]] = group_mean\n",
    "# Standard error of the mean: pandas' std already uses ddof=1, so divide by sqrt(n), not sqrt(n-1).\n",
    "# NOTE(review): n is taken from the seed range in the filename - assumes every group has all seeds.\n",
    "results_df_cp[[c + \"_se\" for c in cols_to_agg]] = group_std / np.sqrt(num_seeds)\n",
    "\n",
    "# Display DataFrame\n",
    "display_cols = [col + agg for col in [\"acc_test\", \"acc_vali\", \"acc_train\"] for agg in [\"_mean\", \"_se\"]]\n",
    "display_cols += [\"num_epochs\"]\n",
    "# numeric_only=True: pandas >= 2.0 raises on non-numeric columns (e.g. data_name) otherwise.\n",
    "display_df = results_df_cp.groupby(groupby_cols).mean(numeric_only=True)\n",
    "\n",
    "print(filename)\n",
    "# For every tuned configuration keep the learning rate with the lowest mean validation value.\n",
    "# NOTE(review): \"min\" is used, so acc_* is presumably an error measure (lower is better) - confirm.\n",
    "if tune_cols:\n",
    "    best_vali = display_df.groupby(tune_cols)[\"acc_vali_mean\"].transform(\"min\")\n",
    "else:\n",
    "    # No tuned columns in the filename: compare all learning rates against the global minimum.\n",
    "    best_vali = display_df[\"acc_vali_mean\"].min()\n",
    "optimal_df = display_df[display_df[\"acc_vali_mean\"] == best_vali]\n",
    "display(optimal_df.sort_values(\"acc_vali_mean\").round(sf)[display_cols])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Multi-Task Learning Baseline experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data parameters\n",
    "data_params = {}\n",
    "data_params[\"data_name\"] = [\"Algorithms\"]\n",
    "data_params[\"STL_metadata\"] = [True,False]\n",
    "data_params[\"STL_onehot\"] = [False]\n",
    "data_params[\"num_tasks\"] = \"all\"\n",
    "data_params[\"num_train_per_task\"] = \"all\"\n",
    "\n",
    "# Model parameters\n",
    "params = {}\n",
    "\n",
    "# CHANGE algo_name to set algorithm: \"STL\" \"MTL\" \"MRN\" \"TF\" \"CS\" \"Sluice\"\n",
    "params[\"algo_name\"] = \"TF\"\n",
    "params[\"lr\"] = [10**p for p in [-3, -3.5, -4, -4.5]]\n",
    "params[\"vali_epoch_freq\"] = 5\n",
    "params[\"vali_epoch_delay\"] = 20\n",
    "\n",
    "params[\"batch_size\"] = 64\n",
    "params[\"batch_shuffle\"] = True\n",
    "params[\"normalise_task_training\"] = False\n",
    "params[\"max_epochs\"] = 2000\n",
    "\n",
    "# Network architecture\n",
    "out_dim = get_data(data_params[\"data_name\"][0])[\"y_train\"].shape[1]\n",
    "arch_by_algo = {\n",
    "    \"STL\": [[[32,32,32,out_dim],[]]],\n",
    "    \"MTL\": [[[32,32,32],[out_dim]]],\n",
    "    \"MRN\": [[[32,32,32],[out_dim]]],\n",
    "    \"TF\": [[[],[32,32,32,out_dim]]],\n",
    "    \"CS\": [[[],[32,32,32,out_dim]]],\n",
    "}\n",
    "# NOTE(review): no architecture is listed for \"Sluice\" - confirm whether train_MTL needs params[\"arch\"] for it.\n",
    "if params[\"algo_name\"] in arch_by_algo:\n",
    "    params[\"arch\"] = arch_by_algo[params[\"algo_name\"]]\n",
    "\n",
    "# Determine which method is being used, from algo name variable.\n",
    "# Each flag gets its own key; previously \"is_MRN\" was overwritten by the CS and Sluice\n",
    "# comparisons and is_CS / is_sluice were always False.\n",
    "params[\"is_MRN\"] = params[\"algo_name\"] == \"MRN\"\n",
    "params[\"is_TF\"] = params[\"algo_name\"] == \"TF\"\n",
    "params[\"is_CS\"] = params[\"algo_name\"] == \"CS\"\n",
    "params[\"is_sluice\"] = params[\"algo_name\"] == \"Sluice\"\n",
    "\n",
    "# Specific parameters - Multilinear Network\n",
    "if params[\"is_MRN\"]:\n",
    "    params[\"MRN_weight\"] = [10**p for p in [-2, -3, -4, -5]] + [0]\n",
    "    params[\"MRN_feat_k\"] = [1, 0.1, 0.01]\n",
    "\n",
    "# Specific parameters - Tensor Factorisation method (DMTRL)\n",
    "if params[\"is_TF\"]:\n",
    "    params[\"TF_method\"] = [\"Tucker\", \"TT\"]\n",
    "    params[\"TF_k\"] = [2,4,8,16]\n",
    "\n",
    "# Specific parameters - Sluice networks (Cross-stitch has no extra parameters here)\n",
    "if params[\"is_sluice\"]:\n",
    "    params[\"sluice_num_subspaces\"] = 2\n",
    "    params[\"sluice_alpha_init\"] = [\"imbalanced\", \"balanced\"]\n",
    "    params[\"sluice_beta_init\"] = \"imbalanced\"\n",
    "    params[\"sluice_orthogonal_loss_coef\"] = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0]\n",
    "\n",
    "params[\"is_classification\"] = False\n",
    "params[\"loss_func_name\"] = \"mse\"\n",
    "params[\"record_task_sizes\"] = True\n",
    "params[\"verbose\"] = False\n",
    "params[\"plot\"] = False\n",
    "\n",
    "params[\"seed\"] = [*range(0,50)]\n",
    "\n",
    "t = time.time()               # wall-clock start, used for the elapsed time printed at the end\n",
    "t0 = datetime.datetime.now()  # save_df reads this global for the CSV filename timestamp\n",
    "\n",
    "records = []                  # one entry (metrics + data parameters) per finished run\n",
    "results_df = pd.DataFrame()\n",
    "param_loader = param_iterator(data_params, params)\n",
    "\n",
    "# Number of runs per seed; a checkpoint is written every time this many runs complete.\n",
    "# NOTE(review): this assumes param_iterator varies \"seed\" slowest - confirm in src.utility.\n",
    "runs_per_seed = param_loader.num_combinations // len(params[\"seed\"])\n",
    "\n",
    "for i in range(param_loader.num_combinations):\n",
    "    d_p, p = param_loader.next()\n",
    "\n",
    "    MTL = train_MTL(p, d_p)\n",
    "\n",
    "    results = MTL.metrics.copy()\n",
    "    for key in d_p:\n",
    "        results[key] = d_p[key]\n",
    "    records.append(results)\n",
    "\n",
    "    # Save results dataframe\n",
    "    if (i+1) % runs_per_seed == 0:\n",
    "        # DataFrame.append was removed in pandas 2.0, so rows are collected in a list\n",
    "        # and the frame is rebuilt only when a checkpoint is written.\n",
    "        results_df = pd.DataFrame(records)\n",
    "        save_df(results_df)\n",
    "\n",
    "# Leave the complete results in results_df for interactive inspection.\n",
    "results_df = pd.DataFrame(records)\n",
    "\n",
    "print(time.time() - t)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
