{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
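    "## Plotting\n",
    "\n",
    "Load all runs, export per-experiment metric CSVs, and print a LaTeX results table (per-method mean and standard deviation; best value in bold, second best in italic)."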
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "\n",
    "import os\n",
    "import sys\n",
    "from omegaconf import OmegaConf\n",
    "# Put the parent of the working directory (presumably the repo root) first on the import\n",
    "# path so the `floral` package resolves; this must run before the import below.\n",
    "sys.path.insert(0, \"..\")\n",
    "from floral.utils.plotting import (\n",
    "    OUTPUT_DIR, PLOTS_DIR,\n",
    "    load_runs,\n",
    "    histories_to_df,\n",
    "    setup_experiment_plotting_and_variables,\n",
    "    variables_metrics_to_csv,\n",
    ")\n",
    "\n",
    "# Load every run stored under ../OUTPUT_DIR once; the next cells filter them per experiment.\n",
    "HISTORIES = load_runs(output_dir=os.path.join(\"..\", OUTPUT_DIR))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ========== CHOOSE EXPERIMENT ========== #\n",
    "# Every experiment is named \"run_methods_<dataset>\"; comment out a dataset to skip it.\n",
    "EXPERIMENTS = [\n",
    "    \"run_methods_\" + dataset_name\n",
    "    for dataset_name in (\n",
    "        # \"synthetic_linear\",\n",
    "        # \"synthetic_mlp\",\n",
    "        \"mnist_rotate\",\n",
    "        \"mnist_label_shift\",\n",
    "        \"cifar10_rotate\",\n",
    "        \"cifar10_label_shift\",\n",
    "        \"cifar100\",\n",
    "        \"mnist_rotate_reduced\",\n",
    "        \"mnist_label_shift_reduced\",\n",
    "        \"cifar10_rotate_reduced\",\n",
    "        \"cifar10_label_shift_reduced\",\n",
    "        \"cifar100_reduced\",\n",
    "        # \"shakespeare\",  # XXX\n",
    "        # \"emnist\",  # XXX\n",
    "        # \"stackoverflow\",  # XXX\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "run_methods_mnist_rotate...\n",
      "Ok\n",
      "\n",
      "run_methods_mnist_label_shift...\n",
      "Ok\n",
      "\n",
      "run_methods_cifar10_rotate...\n",
      "Ok\n",
      "\n",
      "run_methods_cifar10_label_shift...\n",
      "Ok\n",
      "\n",
      "run_methods_cifar100...\n",
      "Ok\n",
      "\n",
      "run_methods_mnist_rotate_reduced...\n",
      "Ok\n",
      "\n",
      "run_methods_mnist_label_shift_reduced...\n",
      "Ok\n",
      "\n",
      "run_methods_cifar10_rotate_reduced...\n",
      "Ok\n",
      "\n",
      "run_methods_cifar10_label_shift_reduced...\n",
      "Ok\n",
      "\n",
      "run_methods_cifar100_reduced...\n",
      "Ok\n"
     ]
    }
   ],
   "source": [
    "summary_dfs = {}\n",
    "for experiment in EXPERIMENTS:\n",
    "    print(\"\\n\" + experiment + \"...\")\n",
    "    # Keep only the runs of this experiment and ignore nothing. Building the configs from\n",
    "    # dicts avoids round-tripping the experiment name through YAML parsing.\n",
    "    filter_values = OmegaConf.create({\"experiment\": [experiment]})\n",
    "    ignore_values = OmegaConf.create({})\n",
    "    history_df = histories_to_df(\n",
    "        HISTORIES,\n",
    "        filter_values=filter_values,\n",
    "        ignore_values=ignore_values,\n",
    "        #  downsampled_len=500,\n",
    "        hide_na=True,\n",
    "    )\n",
    "    if len(history_df) == 0:\n",
    "        print(\"Failed to find valid runs\")\n",
    "        continue\n",
    "    # Create the output directory only once there are runs to write, so experiments\n",
    "    # without results do not leave empty folders behind.\n",
    "    results_dir = os.path.join(\"..\", PLOTS_DIR, experiment)\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "    history_df, plot_opts, variables = setup_experiment_plotting_and_variables(history_df, experiment)\n",
    "    # Presumably writes metric CSVs into results_dir; the returned frame is the\n",
    "    # per-experiment summary that the table cell below aggregates.\n",
    "    summary_dfs[experiment] = variables_metrics_to_csv(history_df, variables, results_dir)\n",
    "    print(\"Ok\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FedAvg        & \\xmark & 91.5 {\\tiny 0.6}          & 25.8 {\\tiny 2.4}          & 78.2 {\\tiny 0.6}          & 23.2 {\\tiny 0.9}          & 64.4 {\\tiny 0.3}          & 21.9 {\\tiny 0.4}          & 45.6 {\\tiny 0.3}          & 18.7 {\\tiny 0.4}          & 29.2 {\\tiny 1.8}          & 20.7 {\\tiny 1.4}         \\\\\n",
      "Local Adaptor & \\xmark & 86.6 {\\tiny 0.3}          & 84.5 {\\tiny 1.8}          & 47.4 {\\tiny 5.4}          & 32.0 {\\tiny 2.3}          & 66.3 {\\tiny 0.5}          & 68.8 {\\tiny 0.5}          & 33.5 {\\tiny 0.5}          & 30.8 {\\tiny 0.8}          & 85.1 {\\tiny 0.8}          & 39.5 {\\tiny 2.8}         \\\\\n",
      "Ensemble      & \\xmark & 92.0 {\\tiny 0.1}          & 93.8 {\\tiny 0.5}          & 66.7 {\\tiny 5.3}          & 86.4 {\\tiny 0.4}          & {\\it 71.0 {\\tiny 2.8}}    & 46.4 {\\tiny 9.2}          & 42.4 {\\tiny 0.9}          & 41.7 {\\tiny 4.6}          & 86.2 {\\tiny 0.0}          & 43.7 {\\tiny 3.2}         \\\\\n",
      "Ensemble      & \\cmark & {\\bf 95.8 {\\tiny 0.3}}    & {\\bf 95.6 {\\tiny 0.3}}    & {\\bf 88.2 {\\tiny 1.4}}    & {\\bf 87.6 {\\tiny 1.3}}    & {\\bf 73.7 {\\tiny 0.2}}    & {\\bf 73.3 {\\tiny 0.1}}    & 45.0 {\\tiny 0.9}          & {\\bf 45.1 {\\tiny 0.8}}    & {\\bf 92.8 {\\tiny 0.3}}    & {\\bf 55.0 {\\tiny 0.4}}   \\\\\n",
      "FLoRAL(1\\%)   & \\xmark & 91.3 {\\tiny 0.6}          & 89.7 {\\tiny 3.2}          & 73.1 {\\tiny 3.7}          & 46.0 {\\tiny 9.9}          & 65.5 {\\tiny 0.4}          & 62.8 {\\tiny 8.8}          & 45.2 {\\tiny 0.3}          & {\\it 44.2 {\\tiny 0.9}}    & 81.3 {\\tiny 0.5}          & 52.2 {\\tiny 0.5}         \\\\\n",
      "FLoRAL(1\\%)   & \\cmark & 93.9 {\\tiny 0.8}          & 93.7 {\\tiny 0.2}          & {\\it 87.3 {\\tiny 1.5}}    & {\\it 87.6 {\\tiny 0.5}}    & 68.9 {\\tiny 0.2}          & {\\it 72.2 {\\tiny 0.2}}    & {\\bf 47.8 {\\tiny 0.9}}    & 44.1 {\\tiny 0.6}          & 82.4 {\\tiny 0.2}          & 53.1 {\\tiny 0.4}         \\\\\n",
      "FLoRAL(10\\%)  & \\xmark & 91.8 {\\tiny 1.0}          & 93.1 {\\tiny 0.9}          & 75.7 {\\tiny 2.3}          & 70.8 {\\tiny 7.1}          & 65.1 {\\tiny 0.3}          & 56.2 {\\tiny 5.5}          & 44.5 {\\tiny 0.4}          & 42.1 {\\tiny 0.2}          & {\\it 87.3 {\\tiny 0.3}}    & 51.2 {\\tiny 1.0}         \\\\\n",
      "FLoRAL(10\\%)  & \\cmark & {\\it 94.5 {\\tiny 0.6}}    & {\\it 94.2 {\\tiny 0.2}}    & 87.0 {\\tiny 0.7}          & 86.9 {\\tiny 0.5}          & 69.3 {\\tiny 0.5}          & 72.1 {\\tiny 0.5}          & {\\it 47.2 {\\tiny 0.3}}    & 42.7 {\\tiny 0.3}          & 86.6 {\\tiny 0.5}          & {\\it 53.9 {\\tiny 0.9}}   \\\\\n"
     ]
    }
   ],
   "source": [
    "COL_SEP = \" & \"  # separator between LaTeX table columns\n",
    "DECIMALS = 1  # digits printed for means and standard deviations\n",
    "METHOD_COLNAME = \"Method\"  # summary-dataframe column holding the method name\n",
    "OPTIMAL_ROUTER_COLNAME = \"Optimal $\\\\pi$\"  # boolean flag column, printed as a check/cross mark\n",
    "FIELD_SIZE = 25  # pad width of every metric cell so the printed columns line up\n",
    "EMPTY_FIELD = f\"{'-':{FIELD_SIZE}s}\"  # placeholder cell for entries that do not apply\n",
    "\n",
    "# The best and second-best value of each column are highlighted in bold and italic\n",
    "# respectively (see get_topk_methods, get_marker and metric_to_latex below).\n",
    "\n",
    "# Table rows, top to bottom: (method name, whether the optimal router is used)\n",
    "SORTED_METHODS = [\n",
    "    (\"FedAvg\", False),\n",
    "    (\"Local Adaptor\", False),\n",
    "    (\"Ensemble\", False),\n",
    "    (\"Ensemble\", True),\n",
    "    (\"FLoRAL(1%)\", False),\n",
    "    (\"FLoRAL(1%)\", True),\n",
    "    (\"FLoRAL(10%)\", False),\n",
    "    (\"FLoRAL(10%)\", True),\n",
    "]\n",
    "# Width of the method-name column, measured on the LaTeX-escaped names that are\n",
    "# actually printed (each % is escaped, so it takes two characters).\n",
    "METHOD_FIELD_SIZE = max(len(m.replace(\"%\", \"\\\\%\")) for m, _ in SORTED_METHODS)\n",
    "\n",
    "# Table columns, left to right; columns whose dataset was not loaded are skipped\n",
    "SORTED_DATASETS = [\n",
    "    \"mnist_rotate\",\n",
    "    \"mnist_label_shift\",\n",
    "    \"mnist_rotate_reduced\",\n",
    "    \"mnist_label_shift_reduced\",\n",
    "    \"cifar10_rotate\",\n",
    "    \"cifar10_label_shift\",\n",
    "    \"cifar10_rotate_reduced\",\n",
    "    \"cifar10_label_shift_reduced\",\n",
    "    \"cifar100\",\n",
    "    \"cifar100_reduced\",\n",
    "    \"shakespeare_top1\",\n",
    "    \"shakespeare_top5\",\n",
    "]\n",
    "\n",
    "\n",
    "def metric_to_latex(mean, std, decimals=DECIMALS, field_size=FIELD_SIZE, marker=None):\n",
    "    \"\"\"Format a mean/std pair as one fixed-width LaTeX table cell.\n",
    "\n",
    "    The mean is followed by the std in tiny font, both rounded to `decimals` digits.\n",
    "    marker=1 wraps the cell in bold (best), marker=2 in italic (second best); any\n",
    "    other marker leaves it plain. The text is left-aligned and padded to `field_size`\n",
    "    characters so the printed table columns line up.\n",
    "\n",
    "    Missing values: a NaN mean renders as a \"-\" placeholder (no marker), and a NaN\n",
    "    std (e.g. a group with a single run) is omitted instead of printing \"nan\".\n",
    "    \"\"\"\n",
    "    import math\n",
    "\n",
    "    if math.isnan(mean):\n",
    "        return f\"{'-':{field_size}s}\"\n",
    "    acc_str = f\"{mean:.{decimals}f}\"\n",
    "    if not math.isnan(std):\n",
    "        acc_str += \" {\\\\tiny \" + f\"{std:.{decimals}f}\" + \"}\"\n",
    "    if marker == 1:\n",
    "        acc_str = \"{\\\\bf \" + acc_str + \"}\"\n",
    "    elif marker == 2:\n",
    "        acc_str = \"{\\\\it \" + acc_str + \"}\"\n",
    "    return f\"{acc_str:{field_size}s}\"\n",
    "\n",
    "\n",
    "# Rank the methods per metric to find the best and second-best entries\n",
    "def get_topk_methods(df_means, k=2, lower_is_better_substrings=(\"loss\",)):\n",
    "    \"\"\"Return the top-`k` (method, optimal_router) pairs for every metric column.\n",
    "\n",
    "    Args:\n",
    "        df_means: per-group means with METHOD_COLNAME and OPTIMAL_ROUTER_COLNAME as\n",
    "            regular columns and one column per metric.\n",
    "        k: number of ranks to return (1 = best, 2 = second best, ...).\n",
    "        lower_is_better_substrings: a metric column whose name contains any of these\n",
    "            substrings is ranked lowest-first (e.g. losses); all others highest-first.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping rank (1..k) to a Series indexed by metric column whose values are\n",
    "        (method, optimal_router) tuples, or None when fewer than `rank` groups have a\n",
    "        non-NaN value in that column.\n",
    "    \"\"\"\n",
    "    import pandas as pd\n",
    "\n",
    "    assert k >= 0\n",
    "    df_means_by_method = df_means.set_index([METHOD_COLNAME, OPTIMAL_ROUTER_COLNAME])\n",
    "    top_methods = {rank: {} for rank in range(1, k + 1)}\n",
    "    for col in df_means_by_method.columns:\n",
    "        ascending = any(token in str(col) for token in lower_is_better_substrings)\n",
    "        # NaNs are dropped rather than ranked; the stable sort keeps the first-listed\n",
    "        # group on ties, as idxmax would.\n",
    "        ranked = df_means_by_method[col].dropna().sort_values(ascending=ascending, kind=\"mergesort\").index\n",
    "        for rank in top_methods:\n",
    "            top_methods[rank][col] = ranked[rank - 1] if rank <= len(ranked) else None\n",
    "    return {rank: pd.Series(labels, dtype=object) for rank, labels in top_methods.items()}\n",
    "\n",
    "\n",
    "def get_marker(method, optimal_router, topk_methods, metric):\n",
    "    \"\"\"Return the rank held by (method, optimal_router) for `metric`, or None.\n",
    "\n",
    "    `topk_methods` maps each rank to a lookup from metric name to the\n",
    "    (method, optimal_router) pair holding that rank. Every rank is checked; if more\n",
    "    than one matches, the last one iterated is returned.\n",
    "    \"\"\"\n",
    "    target = (method, optimal_router)\n",
    "    matching_ranks = [rank for rank, holders in topk_methods.items() if holders[metric] == target]\n",
    "    return matching_ranks[-1] if matching_ranks else None\n",
    "\n",
    "\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def _dataset_columns(dataset):\n",
    "    \"\"\"Return the (table column, metric) pairs reported for one dataset.\n",
    "\n",
    "    Shakespeare contributes two table columns (top-1 and top-5 accuracy) named as in\n",
    "    SORTED_DATASETS; synthetic datasets report loss; all others report accuracy.\n",
    "    \"\"\"\n",
    "    if dataset == \"shakespeare\":\n",
    "        return [\n",
    "            (\"shakespeare_top1\", \"accuracy_top1_distributed\"),\n",
    "            (\"shakespeare_top5\", \"accuracy_top5_distributed\"),\n",
    "        ]\n",
    "    if \"synthetic\" in dataset:\n",
    "        return [(dataset, \"loss_distributed\")]\n",
    "    return [(dataset, \"acc_distributed\")]\n",
    "\n",
    "\n",
    "# results[(method, optimal_router)][table column] -> formatted LaTeX cell\n",
    "results = defaultdict(dict)\n",
    "# Table columns (names as in SORTED_DATASETS) backed by a loaded experiment\n",
    "available_datasets = []\n",
    "for experiment, df in summary_dfs.items():\n",
    "    dataset = experiment[len(\"run_methods_\"):]\n",
    "    dataset_columns = _dataset_columns(dataset)\n",
    "    # Register table column names rather than the raw dataset name: the printing loop\n",
    "    # looks columns up by these names, and for Shakespeare they differ from the dataset\n",
    "    # name (\"shakespeare\" vs \"shakespeare_top1\"/\"shakespeare_top5\").\n",
    "    available_datasets.extend(column for column, _ in dataset_columns)\n",
    "    if OPTIMAL_ROUTER_COLNAME not in df.columns:\n",
    "        # Runs without the flag column count as not using the optimal router. assign()\n",
    "        # returns a copy, so summary_dfs is not mutated and the cell can be re-run.\n",
    "        df = df.assign(**{OPTIMAL_ROUTER_COLNAME: False})\n",
    "    df_groupedby_seed = df.groupby([METHOD_COLNAME, OPTIMAL_ROUTER_COLNAME])\n",
    "    df_means = df_groupedby_seed.mean().reset_index()\n",
    "    df_stds = df_groupedby_seed.std().reset_index()\n",
    "    topk_methods = get_topk_methods(df_means, k=2)\n",
    "    for method, optimal_router in SORTED_METHODS:\n",
    "        mean_row = df_means[(df_means[METHOD_COLNAME] == method) & (df_means[OPTIMAL_ROUTER_COLNAME] == optimal_router)]\n",
    "        std_row = df_stds[(df_stds[METHOD_COLNAME] == method) & (df_stds[OPTIMAL_ROUTER_COLNAME] == optimal_router)]\n",
    "        if len(mean_row) == 0 or len(std_row) == 0:\n",
    "            continue\n",
    "        mean_row = mean_row.iloc[0]\n",
    "        std_row = std_row.iloc[0]\n",
    "        for column, metric in dataset_columns:\n",
    "            marker = get_marker(method, optimal_router, topk_methods, metric)\n",
    "            results[(method, optimal_router)][column] = metric_to_latex(mean_row[metric], std_row[metric], marker=marker)\n",
    "            if dataset == \"shakespeare\":\n",
    "                # Show \"-\" in the optimal-router row (presumably there are no such Shakespeare\n",
    "                # runs), but never overwrite a value that was actually computed.\n",
    "                results[(method, True)].setdefault(column, EMPTY_FIELD)\n",
    "\n",
    "for method, optimal_router in SORTED_METHODS:\n",
    "    # Escape % in method names: LaTeX treats a bare % as the start of a comment\n",
    "    label = method.replace(\"%\", \"\\\\%\")\n",
    "    router_mark = \"\\\\cmark\" if optimal_router else \"\\\\xmark\"\n",
    "    # .get avoids inserting empty entries into the results defaultdict\n",
    "    method_results = results.get((method, optimal_router), {})\n",
    "    # One cell per loaded dataset column; entries without a result go through\n",
    "    # metric_to_latex with NaN inputs\n",
    "    cells = [\n",
    "        method_results[dataset] if dataset in method_results else metric_to_latex(float('nan'), float('nan'))\n",
    "        for dataset in SORTED_DATASETS\n",
    "        if dataset in available_datasets\n",
    "    ]\n",
    "    print(COL_SEP.join([f\"{label:{METHOD_FIELD_SIZE}s}\", router_mark, *cells]) + \"\\\\\\\\\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
