{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tg_solver import MIP, TAG, MIP_constr, random_sampling_with_const\n",
    "from utilis import *  # presumably provides read_infos / show_res used below\n",
    "import pickle\n",
    "\n",
    "task = \"celeba\"\n",
    "all_comb_res = read_infos(task)\n",
    "\n",
    "# Transfer-gain matrix consumed by MIP below. The path is derived from `task`\n",
    "# so switching tasks cannot silently load another task's gains; `with`\n",
    "# guarantees the file handle is closed.\n",
    "# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.\n",
    "with open(f\"transfer_gains/{task}_transfer_gain.pkl\", \"rb\") as f:\n",
    "    S_ours = pickle.load(f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Our results (MIP groupings from our transfer gains)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{Ours} & $6.55\\pm0.016$ & $11.09\\pm0.023$ & $4.19\\pm0.014$ & $12.56\\pm0.101$ & $2.58\\pm0.015$ & $3.02\\pm0.027$ & $4.80\\pm0.017$ & $4.74\\pm0.034$ & $0.70\\pm0.004$ & \\multirow{2}{*}{$49.534$} \\\\\n",
      "&  & $6.60\\pm0.017$ & $11.22\\pm0.042$ & $4.36\\pm0.018$ & $12.05\\pm0.133$ & $2.60\\pm0.010$ & $2.88\\pm0.011$ & $4.82\\pm0.037$ & $4.71\\pm0.037$ & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{Ours} & $6.60\\pm0.017$ & $11.22\\pm0.042$ & $4.36\\pm0.018$ & $12.05\\pm0.133$ & $2.60\\pm0.010$ & $2.88\\pm0.011$ & $4.82\\pm0.037$ & $4.71\\pm0.037$ & - & \\multirow{3}{*}{$49.348$} \\\\\n",
      "&  & $6.55\\pm0.016$ & $11.09\\pm0.023$ & $4.19\\pm0.014$ & $12.56\\pm0.101$ & $2.58\\pm0.015$ & $3.02\\pm0.027$ & $4.80\\pm0.017$ & $4.74\\pm0.034$ & $0.70\\pm0.004$ &  \\\\\n",
      "&  & $6.62\\pm0.043$ & $11.32\\pm0.131$ & $4.09\\pm0.015$ & $12.21\\pm0.051$ & $2.60\\pm0.044$ & $2.93\\pm0.008$ & $4.72\\pm0.053$ & - & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{Ours} & $6.62\\pm0.043$ & $11.32\\pm0.131$ & $4.09\\pm0.015$ & $12.21\\pm0.051$ & $2.60\\pm0.044$ & $2.93\\pm0.008$ & $4.72\\pm0.053$ & - & - & \\multirow{4}{*}{$49.335$} \\\\\n",
      "&  & $6.63\\pm0.026$ & $11.36\\pm0.083$ & $4.34\\pm0.034$ & $12.23\\pm0.057$ & $2.77\\pm0.019$ & - & $4.94\\pm0.009$ & $4.69\\pm0.027$ & - &  \\\\\n",
      "&  & $6.55\\pm0.016$ & $11.09\\pm0.023$ & $4.19\\pm0.014$ & $12.56\\pm0.101$ & $2.58\\pm0.015$ & $3.02\\pm0.027$ & $4.80\\pm0.017$ & $4.74\\pm0.034$ & $0.70\\pm0.004$ &  \\\\\n",
      "&  & $6.60\\pm0.017$ & $11.22\\pm0.042$ & $4.36\\pm0.018$ & $12.05\\pm0.133$ & $2.60\\pm0.010$ & $2.88\\pm0.011$ & $4.82\\pm0.037$ & $4.71\\pm0.037$ & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Solve the grouping MIP on our transfer gains for 2, 3 and 4 groups.\n",
    "for n_groups in (2, 3, 4):\n",
    "    # Only the first return value is needed: it is the grouping shown below.\n",
    "    ours_grouping, _ = MIP(S_ours, n_groups)\n",
    "    show_res(all_comb_res, ours_grouping, task, \"Ours\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{TAG} & $6.50\\pm0.158$ & - & - & - & - & - & $4.88\\pm0.118$ & - & - & \\multirow{2}{*}{$51.036$} \\\\\n",
      "&  & - & $11.16\\pm0.053$ & $4.24\\pm0.038$ & $13.15\\pm0.112$ & $2.66\\pm0.004$ & $2.99\\pm0.006$ & - & $4.74\\pm0.010$ & $0.71\\pm0.005$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{TAG} & $6.50\\pm0.158$ & - & - & - & - & - & $4.88\\pm0.118$ & - & - & \\multirow{3}{*}{$49.897$} \\\\\n",
      "&  & - & $11.14\\pm0.047$ & - & $12.31\\pm0.030$ & - & - & - & - & - &  \\\\\n",
      "&  & - & $11.18\\pm0.160$ & $4.11\\pm0.007$ & - & $2.55\\pm0.029$ & $2.92\\pm0.063$ & $4.95\\pm0.117$ & $4.73\\pm0.013$ & $0.76\\pm0.014$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{TAG} & $6.50\\pm0.158$ & - & - & - & - & - & $4.88\\pm0.118$ & - & - & \\multirow{4}{*}{$49.605$} \\\\\n",
      "&  & - & $11.14\\pm0.047$ & - & $12.31\\pm0.030$ & - & - & - & - & - &  \\\\\n",
      "&  & - & - & - & - & $2.62\\pm0.010$ & $2.99\\pm0.004$ & $4.76\\pm0.009$ & - & - &  \\\\\n",
      "&  & - & $10.97\\pm0.023$ & $4.09\\pm0.033$ & $12.40\\pm0.054$ & - & - & $4.95\\pm0.017$ & $4.63\\pm0.012$ & $0.72\\pm0.002$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Hard-coded TAG baseline groupings (task indices 0-8; groups may overlap).\n",
    "TAG_groupings = [\n",
    "    [[0, 6], [1, 2, 3, 4, 5, 7, 8]],\n",
    "    [[0, 6], [1, 3], [1, 2, 4, 5, 6, 7, 8]],\n",
    "    [[0, 6], [1, 3], [4, 5, 6], [1, 2, 3, 6, 7, 8]],\n",
    "]\n",
    "splits = [2, 3, 4]\n",
    "assert len(TAG_groupings) == len(splits)\n",
    "for split, grouping in zip(splits, TAG_groupings):\n",
    "    # Guard against transcription typos: exactly `split` groups that\n",
    "    # together cover all 9 tasks.\n",
    "    assert len(grouping) == split, f\"TAG, split={split}: got {len(grouping)} groups\"\n",
    "    assert set().union(*grouping) == set(range(9)), f\"TAG, split={split}: not all tasks covered\"\n",
    "    show_res(all_comb_res, grouping, task, \"TAG\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{CS} & - & - & - & - & $2.59\\pm0.026$ & $3.06\\pm0.017$ & - & - & - & \\multirow{2}{*}{$50.278$} \\\\\n",
      "&  & $6.58\\pm0.028$ & $11.15\\pm0.027$ & $4.26\\pm0.097$ & $12.35\\pm0.229$ & $2.54\\pm0.029$ & - & $4.86\\pm0.059$ & $4.77\\pm0.083$ & $0.71\\pm0.017$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{CS} & - & - & - & - & $2.59\\pm0.026$ & $3.06\\pm0.017$ & - & - & - & \\multirow{3}{*}{$50.954$} \\\\\n",
      "&  & $6.50\\pm0.158$ & - & - & - & - & - & $4.88\\pm0.118$ & - & - &  \\\\\n",
      "&  & - & $11.18\\pm0.008$ & $4.10\\pm0.013$ & $13.06\\pm0.105$ & - & - & - & $4.83\\pm0.013$ & $0.74\\pm0.004$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{CS} & - & - & - & - & $2.59\\pm0.026$ & $3.06\\pm0.017$ & - & - & - & \\multirow{4}{*}{$49.753$} \\\\\n",
      "&  & $6.50\\pm0.158$ & - & - & - & - & - & $4.88\\pm0.118$ & - & - &  \\\\\n",
      "&  & - & $11.14\\pm0.047$ & - & $12.31\\pm0.030$ & - & - & - & - & - &  \\\\\n",
      "&  & - & $10.97\\pm0.023$ & $4.09\\pm0.033$ & $12.40\\pm0.054$ & - & - & $4.95\\pm0.017$ & $4.63\\pm0.012$ & $0.72\\pm0.002$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Hard-coded CS baseline groupings (task indices 0-8; groups may overlap).\n",
    "CS_groupings = [\n",
    "    [[4, 5], [0, 1, 2, 3, 4, 6, 7, 8]],\n",
    "    [[4, 5], [0, 6], [1, 2, 3, 7, 8]],\n",
    "    [[4, 5], [0, 6], [1, 3], [1, 2, 3, 6, 7, 8]],\n",
    "]\n",
    "splits = [2, 3, 4]\n",
    "assert len(CS_groupings) == len(splits)\n",
    "for split, grouping in zip(splits, CS_groupings):\n",
    "    # Guard against transcription typos: exactly `split` groups that\n",
    "    # together cover all 9 tasks.\n",
    "    assert len(grouping) == split, f\"CS, split={split}: got {len(grouping)} groups\"\n",
    "    assert set().union(*grouping) == set(range(9)), f\"CS, split={split}: not all tasks covered\"\n",
    "    show_res(all_comb_res, grouping, task, \"CS\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### HOA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{HOA} & - & - & - & - & $2.57\\pm0.019$ & - & $4.69\\pm0.024$ & - & - & \\multirow{2}{*}{$49.600$} \\\\\n",
      "&  & $6.55\\pm0.018$ & $11.27\\pm0.026$ & $4.14\\pm0.009$ & $11.87\\pm0.045$ & - & $3.03\\pm0.016$ & - & $4.81\\pm0.025$ & $0.67\\pm0.006$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{HOA} & - & - & - & - & $2.57\\pm0.019$ & - & $4.69\\pm0.024$ & - & - & \\multirow{3}{*}{$49.716$} \\\\\n",
      "&  & - & $11.14\\pm0.047$ & - & $12.31\\pm0.030$ & - & - & - & - & - &  \\\\\n",
      "&  & $6.55\\pm0.024$ & - & $4.12\\pm0.023$ & - & $2.70\\pm0.006$ & $2.84\\pm0.002$ & $4.89\\pm0.021$ & $4.76\\pm0.011$ & $0.73\\pm0.012$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{HOA} & - & - & - & - & $2.57\\pm0.019$ & - & $4.69\\pm0.024$ & - & - & \\multirow{4}{*}{$49.853$} \\\\\n",
      "&  & - & $11.14\\pm0.047$ & - & $12.31\\pm0.030$ & - & - & - & - & - &  \\\\\n",
      "&  & $6.65\\pm0.033$ & - & - & - & $2.68\\pm0.031$ & - & - & - & - &  \\\\\n",
      "&  & - & $11.18\\pm0.160$ & $4.11\\pm0.007$ & - & $2.55\\pm0.029$ & $2.92\\pm0.063$ & $4.95\\pm0.117$ & $4.73\\pm0.013$ & $0.76\\pm0.014$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Hard-coded HOA baseline groupings (task indices 0-8; groups may overlap).\n",
    "HOA_groupings = [\n",
    "    [[4, 6], [0, 1, 2, 3, 5, 7, 8]],\n",
    "    [[4, 6], [1, 3], [0, 2, 4, 5, 6, 7, 8]],\n",
    "    [[4, 6], [1, 3], [0, 4], [1, 2, 4, 5, 6, 7, 8]],\n",
    "]\n",
    "splits = [2, 3, 4]\n",
    "assert len(HOA_groupings) == len(splits)\n",
    "for split, grouping in zip(splits, HOA_groupings):\n",
    "    # Guard against transcription typos: exactly `split` groups that\n",
    "    # together cover all 9 tasks.\n",
    "    assert len(grouping) == split, f\"HOA, split={split}: got {len(grouping)} groups\"\n",
    "    assert set().union(*grouping) == set(range(9)), f\"HOA, split={split}: not all tasks covered\"\n",
    "    show_res(all_comb_res, grouping, task, \"HOA\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### OPT groupings from the TAG paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{2}{*}{OPT} & - & $11.16\\pm0.002$ & $4.04\\pm0.028$ & $12.24\\pm0.367$ & - & - & - & - & $0.74\\pm0.025$ & \\multirow{2}{*}{$49.583$} \\\\\n",
      "&  & $6.60\\pm0.017$ & $11.22\\pm0.042$ & $4.36\\pm0.018$ & $12.05\\pm0.133$ & $2.60\\pm0.010$ & $2.88\\pm0.011$ & $4.82\\pm0.037$ & $4.71\\pm0.037$ & - &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{3}{*}{OPT} & $6.30\\pm0.036$ & - & - & - & - & - & $4.67\\pm0.013$ & - & $0.71\\pm0.009$ & \\multirow{3}{*}{$49.323$} \\\\\n",
      "&  & - & $11.16\\pm0.002$ & $4.04\\pm0.028$ & $12.24\\pm0.367$ & - & - & - & - & $0.74\\pm0.025$ &  \\\\\n",
      "&  & $6.27\\pm0.027$ & - & - & - & $2.59\\pm0.006$ & $2.92\\pm0.011$ & $4.73\\pm0.016$ & $4.72\\pm0.019$ & $0.76\\pm0.002$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{4}{*}{OPT} & - & - & $4.08\\pm0.079$ & - & - & $2.94\\pm0.021$ & $4.82\\pm0.033$ & - & $0.74\\pm0.006$ & \\multirow{4}{*}{$49.169$} \\\\\n",
      "&  & - & $11.29\\pm0.129$ & $4.12\\pm0.056$ & $12.59\\pm0.090$ & - & - & - & - & - &  \\\\\n",
      "&  & $6.30\\pm0.036$ & - & - & - & - & - & $4.67\\pm0.013$ & - & $0.71\\pm0.009$ &  \\\\\n",
      "&  & - & $11.11\\pm0.020$ & - & $11.92\\pm0.012$ & $2.68\\pm0.019$ & - & - & $4.77\\pm0.006$ & $0.78\\pm0.004$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Hard-coded OPT groupings (task indices 0-8; groups may overlap).\n",
    "OPT_groupings = [\n",
    "    [[1, 2, 3, 8], [0, 1, 2, 3, 4, 5, 6, 7]],\n",
    "    [[0, 6, 8], [1, 2, 3, 8], [0, 4, 5, 6, 7, 8]],\n",
    "    [[2, 5, 6, 8], [1, 2, 3], [0, 6, 8], [1, 3, 4, 7, 8]],\n",
    "]\n",
    "splits = [2, 3, 4]\n",
    "assert len(OPT_groupings) == len(splits)\n",
    "for split, grouping in zip(splits, OPT_groupings):\n",
    "    # Guard against transcription typos: exactly `split` groups that\n",
    "    # together cover all 9 tasks.\n",
    "    assert len(grouping) == split, f\"OPT, split={split}: got {len(grouping)} groups\"\n",
    "    assert set().union(*grouping) == set(range(9)), f\"OPT, split={split}: not all tasks covered\"\n",
    "    show_res(all_comb_res, grouping, task, \"OPT\")\n",
    "    print(\"-\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### STL and MTL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "& \\multirow{9}{*}{STL} & $6.47\\pm0.044$ & - & - & - & - & - & - & - & - & \\multirow{9}{*}{$50.622$} \\\\\n",
      "&  & - & $11.27\\pm0.037$ & - & - & - & - & - & - & - &  \\\\\n",
      "&  & - & - & $4.19\\pm0.006$ & - & - & - & - & - & - &  \\\\\n",
      "&  & - & - & - & $12.29\\pm0.020$ & - & - & - & - & - &  \\\\\n",
      "&  & - & - & - & - & $2.72\\pm0.038$ & - & - & - & - &  \\\\\n",
      "&  & - & - & - & - & - & $3.12\\pm0.015$ & - & - & - &  \\\\\n",
      "&  & - & - & - & - & - & - & $4.98\\pm0.031$ & - & - &  \\\\\n",
      "&  & - & - & - & - & - & - & - & $4.85\\pm0.019$ & - &  \\\\\n",
      "&  & - & - & - & - & - & - & - & - & $0.73\\pm0.007$ &  \\\\\n",
      "----------------------------------------------------------------------------------------------------\n",
      "& \\multirow{1}{*}{MTL} & $6.55\\pm0.016$ & $11.09\\pm0.023$ & $4.19\\pm0.014$ & $12.56\\pm0.101$ & $2.58\\pm0.015$ & $3.02\\pm0.027$ & $4.80\\pm0.017$ & $4.74\\pm0.034$ & $0.70\\pm0.004$ & \\multirow{1}{*}{$50.227$} \\\\\n"
     ]
    }
   ],
   "source": [
    "n_tasks = 9\n",
    "stl_grouping = [[t] for t in range(n_tasks)]  # STL: every task in its own group\n",
    "mtl_grouping = [list(range(n_tasks))]          # MTL: a single group with all tasks\n",
    "show_res(all_comb_res, stl_grouping, task, \"STL\")\n",
    "print(\"-\" * 100)\n",
    "show_res(all_comb_res, mtl_grouping, task, \"MTL\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "game",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
