{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "from tg_solver import MIP, TAG  # NOTE(review): TAG is not called in this notebook\n",
    "from utilis import read_infos, show_res\n",
    "\n",
    "# Task/dataset name; also selects the precomputed transfer-gain file loaded below.\n",
    "task = \"ettm1\"\n",
    "all_comb_res = read_infos(task)\n",
    "\n",
    "# Precomputed transfer gains consumed by MIP in the next section.\n",
    "# pickle.load can execute arbitrary code: only load trusted, locally produced files.\n",
    "with open(f\"transfer_gains/{task}_transfer_gain.pkl\", \"rb\") as f:\n",
    "    S_ours = pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Our results (MIP grouping on our transfer gains)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{Ours} & - & - & $0.65\\pm0.055$ & - & $0.56\\pm0.011$ & $0.27\\pm0.009$ & $0.16\\pm0.006$ & \\multirow{2}{*}{$2.989$} \\\\\n",
      "&  & $0.61\\pm0.015$ & $0.38\\pm0.012$ & - & $0.36\\pm0.014$ & $0.57\\pm0.010$ & - & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{Ours} & - & - & $0.65\\pm0.055$ & - & $0.56\\pm0.011$ & $0.27\\pm0.009$ & $0.16\\pm0.006$ & \\multirow{3}{*}{$2.979$} \\\\\n",
      "&  & $0.61\\pm0.015$ & $0.38\\pm0.012$ & - & $0.36\\pm0.014$ & $0.57\\pm0.010$ & - & - &  \\\\\n",
      "&  & - & $0.37\\pm0.007$ & - & - & $0.58\\pm0.037$ & $0.27\\pm0.012$ & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{Ours} & - & $0.37\\pm0.007$ & - & - & $0.58\\pm0.037$ & $0.27\\pm0.012$ & - & \\multirow{4}{*}{$2.966$} \\\\\n",
      "&  & - & - & $0.65\\pm0.055$ & - & $0.56\\pm0.011$ & $0.27\\pm0.009$ & $0.16\\pm0.006$ &  \\\\\n",
      "&  & - & - & - & - & $0.58\\pm0.016$ & $0.27\\pm0.008$ & $0.15\\pm0.005$ &  \\\\\n",
      "&  & $0.61\\pm0.015$ & $0.38\\pm0.012$ & - & $0.36\\pm0.014$ & $0.57\\pm0.010$ & - & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# For 2, 3 and 4 groups, compute a task grouping with the MIP solver on our\n",
    "# transfer gains and print its LaTeX table rows (MIP's second return value is unused).\n",
    "for split in [2, 3, 4]:\n",
    "    our_res, _ = MIP(S_ours, split)\n",
    "    show_res(all_comb_res, our_res, task, \"Ours\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{TAG} & $0.66\\pm0.041$ & $0.37\\pm0.005$ & $0.66\\pm0.022$ & $0.36\\pm0.007$ & - & - & - & \\multirow{2}{*}{$3.032$} \\\\\n",
      "&  & $0.66\\pm0.014$ & $0.38\\pm0.006$ & $0.64\\pm0.056$ & $0.35\\pm0.009$ & $0.59\\pm0.016$ & $0.28\\pm0.005$ & $0.15\\pm0.009$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{TAG} & $0.64\\pm0.016$ & $0.38\\pm0.008$ & $0.65\\pm0.017$ & - & - & - & - & \\multirow{3}{*}{$3.037$} \\\\\n",
      "&  & $0.64\\pm0.026$ & - & $0.67\\pm0.025$ & $0.36\\pm0.007$ & $0.58\\pm0.021$ & - & - &  \\\\\n",
      "&  & $0.63\\pm0.030$ & - & $0.66\\pm0.037$ & $0.37\\pm0.016$ & $0.57\\pm0.022$ & $0.29\\pm0.016$ & $0.16\\pm0.001$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{TAG} & $0.64\\pm0.016$ & $0.38\\pm0.008$ & $0.65\\pm0.017$ & - & - & - & - & \\multirow{4}{*}{$3.026$} \\\\\n",
      "&  & $0.63\\pm0.004$ & - & $0.63\\pm0.030$ & - & $0.58\\pm0.025$ & $0.28\\pm0.008$ & - &  \\\\\n",
      "&  & - & - & $0.68\\pm0.039$ & $0.36\\pm0.013$ & - & - & - &  \\\\\n",
      "&  & - & - & $0.70\\pm0.054$ & - & - & - & $0.16\\pm0.007$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Hardcoded task groupings (task indices 0-6) for 2, 3 and 4 groups.\n",
    "# NOTE(review): presumably obtained from a separate TAG run; not recomputed here - TODO confirm source.\n",
    "TAG_groupings = [\n",
    "    [[0, 1, 2, 3], [0, 1, 2, 3, 4, 5, 6]],\n",
    "    [[0, 1, 2], [0, 2, 3, 4], [0, 2, 3, 4, 5, 6]],\n",
    "    [[0, 1, 2], [0, 2, 4, 5], [2, 3], [2, 6]],\n",
    "]\n",
    "for split, grouping in zip([2, 3, 4], TAG_groupings):\n",
    "    # Guard against a grouping being edited out of sync with its split size.\n",
    "    assert len(grouping) == split, f\"expected {split} groups, got {len(grouping)}\"\n",
    "    show_res(all_comb_res, grouping, task, \"TAG\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{OPT} & - & $0.37\\pm0.007$ & - & $0.35\\pm0.003$ & $0.56\\pm0.029$ & - & $0.15\\pm0.011$ & \\multirow{2}{*}{$2.926$} \\\\\n",
      "&  & $0.61\\pm0.015$ & - & $0.62\\pm0.036$ & $0.36\\pm0.007$ & $0.57\\pm0.036$ & $0.28\\pm0.008$ & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{OPT} & - & - & - & $0.35\\pm0.018$ & $0.55\\pm0.022$ & $0.27\\pm0.015$ & - & \\multirow{3}{*}{$2.913$} \\\\\n",
      "&  & - & $0.37\\pm0.007$ & - & $0.35\\pm0.003$ & $0.56\\pm0.029$ & - & $0.15\\pm0.011$ &  \\\\\n",
      "&  & $0.61\\pm0.015$ & - & $0.62\\pm0.036$ & $0.36\\pm0.007$ & $0.57\\pm0.036$ & $0.28\\pm0.008$ & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{OPT} & $0.66\\pm0.028$ & - & - & - & - & - & $0.14\\pm0.005$ & \\multirow{4}{*}{$2.906$} \\\\\n",
      "&  & - & - & - & $0.35\\pm0.018$ & $0.55\\pm0.022$ & $0.27\\pm0.015$ & - &  \\\\\n",
      "&  & - & $0.37\\pm0.007$ & - & $0.35\\pm0.003$ & $0.56\\pm0.029$ & - & $0.15\\pm0.011$ &  \\\\\n",
      "&  & $0.61\\pm0.015$ & - & $0.62\\pm0.036$ & $0.36\\pm0.007$ & $0.57\\pm0.036$ & $0.28\\pm0.008$ & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Hardcoded task groupings (task indices 0-6) for 2, 3 and 4 groups.\n",
    "# NOTE(review): origin of the OPT groupings is not shown in this notebook - TODO document.\n",
    "OPT_groupings = [\n",
    "    [[1, 3, 4, 6], [0, 2, 3, 4, 5]],\n",
    "    [[3, 4, 5], [1, 3, 4, 6], [0, 2, 3, 4, 5]],\n",
    "    [[0, 6], [3, 4, 5], [1, 3, 4, 6], [0, 2, 3, 4, 5]],\n",
    "]\n",
    "\n",
    "for split, grouping in zip([2, 3, 4], OPT_groupings):\n",
    "    # Guard against a grouping being edited out of sync with its split size.\n",
    "    assert len(grouping) == split, f\"expected {split} groups, got {len(grouping)}\"\n",
    "    show_res(all_comb_res, grouping, task, \"OPT\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### STL and MTL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{7}{*}{STL} & $0.64\\pm0.016$ & - & - & - & - & - & - & \\multirow{7}{*}{$3.047$} \\\\\n",
      "&  & - & $0.37\\pm0.009$ & - & - & - & - & - &  \\\\\n",
      "&  & - & - & $0.68\\pm0.015$ & - & - & - & - &  \\\\\n",
      "&  & - & - & - & $0.36\\pm0.011$ & - & - & - &  \\\\\n",
      "&  & - & - & - & - & $0.56\\pm0.027$ & - & - &  \\\\\n",
      "&  & - & - & - & - & - & $0.29\\pm0.002$ & - &  \\\\\n",
      "&  & - & - & - & - & - & - & $0.15\\pm0.006$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{1}{*}{MTL} & $0.66\\pm0.014$ & $0.38\\pm0.006$ & $0.64\\pm0.056$ & $0.35\\pm0.009$ & $0.59\\pm0.016$ & $0.28\\pm0.005$ & $0.15\\pm0.009$ & \\multirow{1}{*}{$3.033$} \\\\\n"
     ]
    }
   ],
   "source": [
    "# Baselines over the 7 tasks: STL = one group per task, MTL = all tasks in a single group.\n",
    "n_tasks = 7\n",
    "show_res(all_comb_res, [[t] for t in range(n_tasks)], task, \"STL\")\n",
    "print(\"-\" * 100)\n",
    "show_res(all_comb_res, [list(range(n_tasks))], task, \"MTL\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "game",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
