{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tg_solver import MIP, TAG, MIP_constr,random_sampling_with_const\n",
    "from utilis import *\n",
    "task = \"taskonomy\"\n",
    "all_comb_res = read_infos(task)\n",
    "S_ours = pickle.load(open(\"transfer_gains/tky_transfer_gain.pkl\", \"rb\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Our results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{Ours} & $0.138$ & $0.088$ & $-0.028$ & $-0.077$ & $0.052$ & \\multirow{2}{*}{$0.499$} \\\\\n",
      "&  & - & $0.031$ & $-0.068$ & $0.249$ & $0.032$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{Ours} & - & - & $-0.101$ & $0.412$ & $0.067$ & \\multirow{3}{*}{$0.677$} \\\\\n",
      "&  & $0.138$ & $0.088$ & $-0.028$ & $-0.077$ & $0.052$ &  \\\\\n",
      "&  & - & $0.031$ & $-0.068$ & $0.249$ & $0.032$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{Ours} & - & - & $-0.101$ & $0.412$ & $0.067$ & \\multirow{4}{*}{$0.688$} \\\\\n",
      "&  & $0.138$ & $0.088$ & $-0.028$ & $-0.077$ & $0.052$ &  \\\\\n",
      "&  & - & $0.031$ & $-0.068$ & $0.249$ & $0.032$ &  \\\\\n",
      "&  & $0.149$ & - & $-0.062$ & $0.062$ & $0.020$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "for split in [2, 3, 4]:\n",
    "    our_res, _ = MIP(S_ours, split)\n",
    "    show_res(all_comb_res, our_res, task, \"Ours\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{TAG-2} & $-0.008$ & $-0.690$ & $-0.167$ & - & - & \\multirow{2}{*}{$-0.866$} \\\\\n",
      "&  & - & - & - & $-0.013$ & $0.012$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{TAG-3} & $0.052$ & $0.020$ & - & - & - & \\multirow{3}{*}{$0.206$} \\\\\n",
      "&  & - & $0.146$ & $0.008$ & - & - &  \\\\\n",
      "&  & - & - & - & $-0.013$ & $0.012$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{TAG-4} & $0.052$ & $0.020$ & - & - & - & \\multirow{4}{*}{$0.206$} \\\\\n",
      "&  & - & $0.146$ & $0.008$ & - & - &  \\\\\n",
      "&  & $-0.008$ & $-0.690$ & $-0.167$ & - & - &  \\\\\n",
      "&  & - & - & - & $-0.013$ & $0.012$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "TAG_groupings = [\n",
    "    [[0, 1, 2], [3, 4]],\n",
    "    [[0, 1], [1, 2], [3, 4]],\n",
    "    [[0,1],[1,2], [0, 1, 2], [3, 4]],\n",
    "]\n",
    "for i, split in enumerate([2, 3,4]):\n",
    "    show_res(all_comb_res, TAG_groupings[i], task, \"TAG-{}\".format(split))\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{CS} & $-0.008$ & $-0.690$ & $-0.167$ & - & - & \\multirow{2}{*}{$-0.866$} \\\\\n",
      "&  & - & - & - & $-0.013$ & $0.012$ &  \\\\\n",
      "Results for Short:\n",
      "& CS & $-0.008$ & $-0.690$ & $-0.167$ & $-0.013$ & $0.012$ &$ -0.866$\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{CS} & $0.052$ & $0.020$ & - & - & - & \\multirow{3}{*}{$0.206$} \\\\\n",
      "&  & - & $0.146$ & $0.008$ & - & - &  \\\\\n",
      "&  & - & - & - & $-0.013$ & $0.012$ &  \\\\\n",
      "Results for Short:\n",
      "& CS & $0.052$ & $0.146$ & $0.008$ & $-0.013$ & $0.012$ &$ 0.206$\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{CS} & $0.052$ & $0.020$ & - & - & - & \\multirow{4}{*}{$0.206$} \\\\\n",
      "&  & - & $0.146$ & $0.008$ & - & - &  \\\\\n",
      "&  & $-0.008$ & $-0.690$ & $-0.167$ & - & - &  \\\\\n",
      "&  & - & - & - & $-0.013$ & $0.012$ &  \\\\\n",
      "Results for Short:\n",
      "& CS & $0.052$ & $0.146$ & $0.008$ & $-0.013$ & $0.012$ &$ 0.206$\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "CS_groupings = [\n",
    "    [[0, 1, 2], [3, 4]],\n",
    "    [[0, 1], [1, 2], [3, 4]],\n",
    "    [[0, 1], [1, 2], [0, 1, 2], [3, 4]],\n",
    "]\n",
    "for i, split in enumerate([2, 3, 4]):\n",
    "    show_res(all_comb_res, CS_groupings[i], task, \"CS\".format(split))\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HOA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{HOA} & $0.072$ & $0.014$ & $-0.002$ & - & $0.080$ & \\multirow{2}{*}{$0.608$} \\\\\n",
      "&  & - & - & $-0.148$ & $0.443$ & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{HOA} & $-0.008$ & $-0.690$ & $-0.167$ & - & - & \\multirow{3}{*}{$-0.391$} \\\\\n",
      "&  & - & - & $-0.148$ & $0.443$ & - &  \\\\\n",
      "&  & - & - & - & $-0.013$ & $0.012$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{HOA} & $0.199$ & - & $-0.101$ & - & - & \\multirow{4}{*}{$0.809$} \\\\\n",
      "&  & - & $0.146$ & $0.008$ & - & - &  \\\\\n",
      "&  & - & - & $-0.148$ & $0.443$ & - &  \\\\\n",
      "&  & - & - & - & $-0.013$ & $0.012$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "HOA_groupings = [\n",
    "    [[0, 1, 2, 4], [2, 3]],\n",
    "    [[0, 1, 2], [2, 3], [3, 4]],\n",
    "    [[0, 2], [1, 2], [2, 3], [3, 4]],\n",
    "]\n",
    "for i, split in enumerate([2, 3,4]):\n",
    "    show_res(all_comb_res, HOA_groupings[i], task, \"HOA\".format(split))\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OPT in TAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{OPT-2} & $0.082$ & $0.044$ & $-0.072$ & $0.318$ & - & \\multirow{2}{*}{$0.532$}\n",
      "\\\\\n",
      "&  & - & - & $-0.101$ & $0.412$ & $0.067$ & \n",
      "\\\\\n",
      "Results for Short:\n",
      "& OPT-2 & $0.082$ & $0.044$ & $-0.072$ & $0.412$ & $0.067$ &$ 0.532$\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{OPT-3} & $-0.008$ & $-0.690$ & $-0.167$ & - & - & \\multirow{3}{*}{$-0.520$}\n",
      "\\\\\n",
      "&  & $0.127$ & - & $-0.045$ & $-0.014$ & - & \n",
      "\\\\\n",
      "&  & - & - & $0.017$ & - & $0.041$ & \n",
      "\\\\\n",
      "Results for Short:\n",
      "& OPT-3 & $0.127$ & $-0.690$ & $0.017$ & $-0.014$ & $0.041$ &$ -0.520$\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "OPT_groupings = [\n",
    "    [[0, 1, 2, 3], [2,3, 4]],\n",
    "    [[0, 1, 2], [0, 2, 3], [2, 4]],\n",
    "]\n",
    "for i, split in enumerate([2, 3]):\n",
    "    show_res(all_comb_res, OPT_groupings[i], task, \"OPT-{}\".format(split))\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### True OPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{OPT-2} & - & - & $-0.148$ & $0.443$ & - & \\multirow{2}{*}{$0.694$} \\\\\n",
      "&  & $0.138$ & $0.088$ & $-0.028$ & $-0.077$ & $0.052$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{OPT-3} & $0.199$ & - & $-0.101$ & - & - & \\multirow{3}{*}{$0.888$} \\\\\n",
      "&  & - & - & $-0.148$ & $0.443$ & - &  \\\\\n",
      "&  & - & $0.117$ & $0.018$ & - & $0.110$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{OPT-4} & - & - & $-0.000$ & - & - & \\multirow{4}{*}{$0.833$} \\\\\n",
      "&  & $0.199$ & - & $-0.101$ & - & - &  \\\\\n",
      "&  & - & $0.146$ & $0.008$ & - & - &  \\\\\n",
      "&  & - & - & $-0.101$ & $0.412$ & $0.067$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "OPT_groupings = [\n",
    "    [[2, 3], [0, 1, 2, 3, 4]],\n",
    "    [[0, 2], [2, 3], [1, 2, 4]],\n",
    "    [[2], [0, 2], [1, 2], [2, 3, 4]],\n",
    "]\n",
    "for i, split in enumerate([2, 3, 4]):\n",
    "    show_res(all_comb_res, OPT_groupings[i], task, \"OPT-{}\".format(split))\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### STL and MTL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{5}{*}{STL} & $-0.000$ & - & - & - & - & \\multirow{5}{*}{$-0.000$} \\\\\n",
      "&  & - & $-0.000$ & - & - & - &  \\\\\n",
      "&  & - & - & $-0.000$ & - & - &  \\\\\n",
      "&  & - & - & - & $-0.000$ & - &  \\\\\n",
      "&  & - & - & - & - & $-0.000$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{1}{*}{MTL} & $0.138$ & $0.088$ & $-0.028$ & $-0.077$ & $0.052$ & \\multirow{1}{*}{$0.173$} \\\\\n"
     ]
    }
   ],
   "source": [
    "show_res(all_comb_res, [[i] for i in range(5)], task, \"STL\")\n",
    "print(\"-\" * 100)\n",
    "show_res(all_comb_res, [[i for i in range(5)]], task, \"MTL\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "game",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
