{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task whose results are loaded and summarised below. Later cells branch on\n",
    "# \"rolling\", \"forming\", \"motor\", \"motor_2D\", \"motor_geometric_pointnet\" and \"heatsink\".\n",
    "TASK = \"heatsink\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX-formatted display labels for each field and aggregate metric, keyed by task name.\n",
    "# Tasks that share a field set reuse one definition below instead of repeating it; every\n",
    "# {**a, **b} merge builds a new dict, so each task still owns an independent mapping with\n",
    "# key order \"all\", fields..., \"custom\", metrics.\n",
    "_ALL_LABEL = {\"all\": r\"\\textbf{All Fields Normalized Avg (-)}\"}\n",
    "_CUSTOM_LABEL = {\"custom\": r\"\\textbf{Rel Custom Error (-)}\"}\n",
    "_METRIC_LABELS = {\n",
    "    \"mae\": r\"\\textbf{MAE (-)}\",\n",
    "    \"r2\": r\"\\textbf{R2 (-)}\",\n",
    "}\n",
    "\n",
    "# field labels shared by the rolling and forming tasks\n",
    "_NODES_LABELS = {\n",
    "    \"deformation\": r\"\\textbf{Deformation (mm)}\",\n",
    "    \"nodes_LE\": r\"\\textbf{Logarithmic Strain ($\\mathbf{\\times 10^{-2}}$)}\",\n",
    "    \"nodes_PEEQ\": r\"\\textbf{Equivalent Plastic Strain ($\\mathbf{\\times 10^{-2}}$)}\",\n",
    "    \"nodes_mises_stress\": r\"\\textbf{Mises Stress (MPa)}\",\n",
    "    \"nodes_stresses\": r\"\\textbf{Stress (MPa)}\",\n",
    "}\n",
    "\n",
    "# field labels shared by the three motor tasks\n",
    "_MOTOR_LABELS = {\n",
    "    \"deformation\": r\"\\textbf{Deformation (m)}\",\n",
    "    \"logarithmic_strain\": r\"\\textbf{Logarithmic Strain ($\\mathbf{\\times 10^{-2}}$)}\",\n",
    "    \"principal_strain\": r\"\\textbf{Principal Strain ($\\mathbf{\\times 10^{-2}}$)}\",\n",
    "    \"stress\": r\"\\textbf{Stress (MPa)}\",\n",
    "    \"stress_cauchy\": r\"\\textbf{Cauchy Stress (MPa)}\",\n",
    "    \"stress_mises\": r\"\\textbf{Mises Stress (MPa)}\",\n",
    "    \"stress_principal\": r\"\\textbf{Principal Stress (MPa)}\",\n",
    "    \"total_strain\": r\"\\textbf{Total Strain ($\\mathbf{\\times 10^{-2}}$)}\",\n",
    "}\n",
    "\n",
    "# NOTE(review): the label says kPa but no unit-rescale cell touches heatsink \"p\";\n",
    "# confirm the stored pressure values are already in kPa.\n",
    "_HEATSINK_LABELS = {\n",
    "    \"U\": r\"\\textbf{Velocity (m/s)}\",\n",
    "    \"p\": r\"\\textbf{Pressure (kPa)}\",\n",
    "    \"T\": r\"\\textbf{Temperature (K)}\",\n",
    "}\n",
    "\n",
    "field_name_maps = {\n",
    "    \"rolling\": {**_ALL_LABEL, **_NODES_LABELS, **_CUSTOM_LABEL, **_METRIC_LABELS},\n",
    "    \"forming\": {**_ALL_LABEL, **_NODES_LABELS, **_CUSTOM_LABEL, **_METRIC_LABELS},\n",
    "    \"motor\": {**_ALL_LABEL, **_MOTOR_LABELS, **_CUSTOM_LABEL, **_METRIC_LABELS},\n",
    "    # these two motor variants define no MAE / R2 labels\n",
    "    \"motor_geometric_pointnet\": {**_ALL_LABEL, **_MOTOR_LABELS, **_CUSTOM_LABEL},\n",
    "    \"motor_2D\": {**_ALL_LABEL, **_MOTOR_LABELS, **_CUSTOM_LABEL},\n",
    "    \"heatsink\": {**_ALL_LABEL, **_HEATSINK_LABELS, **_CUSTOM_LABEL, **_METRIC_LABELS},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 40 entries, 0 to 39\n",
      "Data columns (total 20 columns):\n",
      " #   Column                          Non-Null Count  Dtype  \n",
      "---  ------                          --------------  -----  \n",
      " 0   model_name                      40 non-null     object \n",
      " 1   da_algorithm_name               40 non-null     object \n",
      " 2   model_selection_algorithm_name  40 non-null     object \n",
      " 3   seed                            40 non-null     int64  \n",
      " 4   test_loss_source                40 non-null     float64\n",
      " 5   test_loss_target                40 non-null     float64\n",
      " 6   test_loss_source_mae            40 non-null     float64\n",
      " 7   test_loss_source_r2             40 non-null     float64\n",
      " 8   test_loss_target_mae            40 non-null     float64\n",
      " 9   test_loss_target_r2             40 non-null     float64\n",
      " 10  test_loss_source_deformation    40 non-null     float64\n",
      " 11  test_loss_target_deformation    40 non-null     float64\n",
      " 12  test_loss_source_T              40 non-null     float64\n",
      " 13  test_loss_target_T              40 non-null     float64\n",
      " 14  test_loss_source_U              40 non-null     float64\n",
      " 15  test_loss_target_U              40 non-null     float64\n",
      " 16  test_loss_source_p              40 non-null     float64\n",
      " 17  test_loss_target_p              40 non-null     float64\n",
      " 18  test_loss_source_custom         40 non-null     float64\n",
      " 19  test_loss_target_custom         40 non-null     float64\n",
      "dtypes: float64(16), int64(1), object(3)\n",
      "memory usage: 6.4+ KB\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "df = pd.read_pickle(f\"model_selection_results_rebuttal/results_{TASK}.pkl\")\n",
    "\n",
    "# the heatsink label map has no deformation entry, so its deformation columns are dropped\n",
    "if TASK == \"heatsink\":\n",
    "    df = df.drop(columns=[\"test_loss_source_deformation\", \"test_loss_target_deformation\"])\n",
    "\n",
    "# info() prints its report itself and returns None (print(df.info()) only appends \"None\");\n",
    "# it runs after the drop so the summary describes the frame the later cells use\n",
    "df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 10 entries, 0 to 9\n",
      "Data columns (total 31 columns):\n",
      " #   Column                          Non-Null Count  Dtype  \n",
      "---  ------                          --------------  -----  \n",
      " 0   model_name                      10 non-null     object \n",
      " 1   da_algorithm_name               10 non-null     object \n",
      " 2   model_selection_algorithm_name  10 non-null     object \n",
      " 3   test_loss_source_mean           10 non-null     float64\n",
      " 4   test_loss_source_std            10 non-null     float64\n",
      " 5   test_loss_target_mean           10 non-null     float64\n",
      " 6   test_loss_target_std            10 non-null     float64\n",
      " 7   test_loss_source_mae_mean       10 non-null     float64\n",
      " 8   test_loss_source_mae_std        10 non-null     float64\n",
      " 9   test_loss_source_r2_mean        10 non-null     float64\n",
      " 10  test_loss_source_r2_std         10 non-null     float64\n",
      " 11  test_loss_target_mae_mean       10 non-null     float64\n",
      " 12  test_loss_target_mae_std        10 non-null     float64\n",
      " 13  test_loss_target_r2_mean        10 non-null     float64\n",
      " 14  test_loss_target_r2_std         10 non-null     float64\n",
      " 15  test_loss_source_T_mean         10 non-null     float64\n",
      " 16  test_loss_source_T_std          10 non-null     float64\n",
      " 17  test_loss_target_T_mean         10 non-null     float64\n",
      " 18  test_loss_target_T_std          10 non-null     float64\n",
      " 19  test_loss_source_U_mean         10 non-null     float64\n",
      " 20  test_loss_source_U_std          10 non-null     float64\n",
      " 21  test_loss_target_U_mean         10 non-null     float64\n",
      " 22  test_loss_target_U_std          10 non-null     float64\n",
      " 23  test_loss_source_p_mean         10 non-null     float64\n",
      " 24  test_loss_source_p_std          10 non-null     float64\n",
      " 25  test_loss_target_p_mean         10 non-null     float64\n",
      " 26  test_loss_target_p_std          10 non-null     float64\n",
      " 27  test_loss_source_custom_mean    10 non-null     float64\n",
      " 28  test_loss_source_custom_std     10 non-null     float64\n",
      " 29  test_loss_target_custom_mean    10 non-null     float64\n",
      " 30  test_loss_target_custom_std     10 non-null     float64\n",
      "dtypes: float64(28), object(3)\n",
      "memory usage: 2.6+ KB\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "group_cols = [\"model_name\", \"da_algorithm_name\", \"model_selection_algorithm_name\"]\n",
    "\n",
    "# every metric column written by the evaluation runs\n",
    "loss_cols = [c for c in df.columns if c.startswith(\"test_loss_\")]\n",
    "\n",
    "# each loss column -> mean and std (pandas default ddof=1) over the rows of a group,\n",
    "# presumably one row per seed\n",
    "agg_dict = {col: [\"mean\", \"std\"] for col in loss_cols}\n",
    "\n",
    "# dropna=False keeps a configuration whose group key is missing instead of silently dropping it\n",
    "agg_df = (\n",
    "    df\n",
    "    .groupby(group_cols, dropna=False)[loss_cols]\n",
    "    .agg(agg_dict)\n",
    "    .reset_index()\n",
    ")\n",
    "\n",
    "# flatten the (column, statistic) MultiIndex into \"<column>_<statistic>\";\n",
    "# the group columns come back from reset_index() with an empty statistic\n",
    "agg_df.columns = [\n",
    "    f\"{col}_{stat}\" if stat else col\n",
    "    for col, stat in agg_df.columns.to_flat_index()\n",
    "]\n",
    "\n",
    "# info() prints its own report and returns None, so no print() wrapper\n",
    "agg_df.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rescale fields so values match the units in field_name_maps (MPa, x1e-2 strain);\n",
    "# stresses are presumably stored in Pa and strains as plain fractions.\n",
    "# agg_df is modified in place, so a flag in agg_df.attrs stops a re-run of this cell\n",
    "# from scaling the same frame twice (re-running the aggregation cell builds a fresh,\n",
    "# unflagged frame).\n",
    "if not agg_df.attrs.get(\"units_rescaled\", False):\n",
    "    if TASK in [\"motor\", \"motor_2D\", \"motor_geometric_pointnet\"]:\n",
    "        for col in list(agg_df.columns):\n",
    "            if \"stress\" in col:\n",
    "                agg_df[col] = agg_df[col] / 1e6\n",
    "            if \"strain\" in col:\n",
    "                agg_df[col] = agg_df[col] * 1e2\n",
    "\n",
    "    if TASK in [\"rolling\", \"forming\"]:\n",
    "        for col in list(agg_df.columns):\n",
    "            if (\"nodes_PEEQ\" in col) or (\"nodes_LE\" in col):\n",
    "                agg_df[col] = agg_df[col] * 1e2\n",
    "\n",
    "    agg_df.attrs[\"units_rescaled\"] = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.microsoft.datawrangler.viewer.v0+json": {
       "columns": [
        {
         "name": "index",
         "rawType": "int64",
         "type": "integer"
        },
        {
         "name": "model_name",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "da_algorithm_name",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "model_selection_algorithm_name",
         "rawType": "object",
         "type": "string"
        },
        {
         "name": "test_loss_source_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_mae_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_mae_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_r2_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_r2_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_mae_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_mae_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_r2_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_r2_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_T_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_T_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_T_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_T_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_U_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_U_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_U_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_U_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_p_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_p_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_p_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_p_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_custom_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_source_custom_std",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_custom_mean",
         "rawType": "float64",
         "type": "float"
        },
        {
         "name": "test_loss_target_custom_std",
         "rawType": "float64",
         "type": "float"
        }
       ],
       "ref": "8bd22aa0-4f70-4115-a143-187378e06fe6",
       "rows": [
        [
         "0",
         "PointNet",
         "-",
         "-",
         "0.2891016751527786",
         "0.004505999575628755",
         "0.48385703563690186",
         "0.054345785928984744",
         "0.14378079771995544",
         "0.0024132678031363987",
         "0.8507727831602097",
         "0.003938774963325134",
         "0.25949442759156227",
         "0.033299200791008635",
         "0.5540147572755814",
         "0.1026783342118762",
         "6.149290442466736",
         "0.16704997903131757",
         "14.037474870681763",
         "1.8575235373078134",
         "0.029818573966622353",
         "0.0004441961947674037",
         "0.0429304288700223",
         "0.0038390397798681052",
         "193.43939971923828",
         "8.965762211431802",
         "909.8325805664062",
         "244.47465692401286",
         "0.009294896153733134",
         "0.00025536011951183505",
         "0.025059766601771116",
         "0.00528568166758918"
        ],
        [
         "1",
         "PointNet",
         "deep_coral",
         "DEV",
         "0.22920090705156326",
         "0.016376111790788156",
         "0.3434923365712166",
         "0.01452644303906583",
         "0.15282239392399788",
         "0.01253775491063637",
         "0.8356487303972244",
         "0.020623840421956332",
         "0.24977564439177513",
         "0.008764459000621175",
         "0.6022328436374664",
         "0.0388676250333527",
         "6.992524981498718",
         "1.0735793803978835",
         "16.670664072036743",
         "2.7169223947103522",
         "0.03165459679439664",
         "0.002210157822581708",
         "0.04047316499054432",
         "0.0028139325926236974",
         "178.70929718017578",
         "3.757132306763501",
         "635.996711730957",
         "199.52260979527733",
         "0.010536569636315107",
         "0.0019310754720491875",
         "0.03440989926457405",
         "0.010234601520611283"
        ],
        [
         "2",
         "PointNet",
         "deep_coral",
         "IWV",
         "0.23667749390006065",
         "0.022541035960968004",
         "0.3740626201033592",
         "0.03383052827911308",
         "0.1590474210679531",
         "0.016839669832833887",
         "0.8256534934043884",
         "0.029531001480008316",
         "0.2734804227948189",
         "0.02053145394649102",
         "0.5460159927606583",
         "0.09504832103900142",
         "7.431118488311768",
         "1.4680369165477345",
         "18.48963475227356",
         "4.264002808423925",
         "0.032566715497523546",
         "0.003199490041195637",
         "0.042388226836919785",
         "0.004846827389770375",
         "193.48199462890625",
         "10.639706382838765",
         "953.7728271484375",
         "382.20450274378203",
         "0.011851029237732291",
         "0.0033030528070386533",
         "0.0384215391241014",
         "0.014105032739005263"
        ],
        [
         "3",
         "PointNet",
         "deep_coral",
         "SB",
         "0.21869589760899544",
         "0.004283877544435395",
         "0.3751775994896889",
         "0.02998946995632927",
         "0.1453240066766739",
         "0.003884441795911093",
         "0.8494159430265427",
         "0.004706293602165335",
         "0.2700645364820957",
         "0.02053799801879847",
         "0.5347799956798553",
         "0.06922396105401342",
         "6.291512966156006",
         "0.2549446846120513",
         "17.338783979415894",
         "3.6302174300596706",
         "0.030138222966343164",
         "0.0005424060198892192",
         "0.04435272701084614",
         "0.003183663782830047",
         "189.00383758544922",
         "14.340157769887481",
         "781.2056274414062",
         "69.48523802725627",
         "0.009396887617185712",
         "0.0006219318071564316",
         "0.03495216555893421",
         "0.011380566848181997"
        ],
        [
         "4",
         "PointNet",
         "deep_coral",
         "TB",
         "0.2496483363211155",
         "0.019235805655915282",
         "0.3362131118774414",
         "0.005372906939398351",
         "0.16849126666784286",
         "0.015100980880267043",
         "0.8091689944267273",
         "0.02657193511932297",
         "0.24778571724891663",
         "0.006329830707799042",
         "0.631302610039711",
         "0.02172346831840371",
         "8.283781886100769",
         "1.1629547489957672",
         "16.940192699432373",
         "2.9734529522358404",
         "0.03439356293529272",
         "0.002589629591198228",
         "0.03836929798126221",
         "0.000723274986516257",
         "187.41442108154297",
         "14.548829417778197",
         "771.331428527832",
         "363.39979426522393",
         "0.013549751369282603",
         "0.003510595419348841",
         "0.03542296774685383",
         "0.010819928484127439"
        ],
        [
         "5",
         "Transolver",
         "-",
         "-",
         "0.23724132776260376",
         "0.002229109681381969",
         "0.46990063041448593",
         "0.048539175665313274",
         "0.11050163768231869",
         "0.0010410799113166318",
         "0.89455346763134",
         "0.002207358976316953",
         "0.2532827630639076",
         "0.027032731362966755",
         "0.6030274778604507",
         "0.0705954992777129",
         "4.155053973197937",
         "0.0297551172917311",
         "11.081872940063477",
         "2.352904456906681",
         "0.02383263921365142",
         "0.00022240419043824893",
         "0.04047305230051279",
         "0.004127527109791428",
         "256.3443412780762",
         "15.616403957089226",
         "1623.4956359863281",
         "210.07829084505377",
         "0.006504206685349345",
         "0.0002572161572435599",
         "0.013958288123831153",
         "0.006610002365148208"
        ],
        [
         "6",
         "Transolver",
         "deep_coral",
         "DEV",
         "0.17915914207696915",
         "0.002008413630352386",
         "0.3327941969037056",
         "0.016980647752941783",
         "0.11171585135161877",
         "0.002128895507038795",
         "0.8934242278337479",
         "0.0016900694572083556",
         "0.23954439163208008",
         "0.015293946665807537",
         "0.63944011926651",
         "0.04133209950329379",
         "4.235326170921326",
         "0.14515970006738962",
         "9.698442935943604",
         "1.064434575734366",
         "0.024166688323020935",
         "0.00032049288546609023",
         "0.038186537101864815",
         "0.0014308400815839686",
         "237.45753860473633",
         "4.949936814926264",
         "1600.520751953125",
         "197.42685347417404",
         "0.006569335935637355",
         "0.0003458554248002175",
         "0.010379871935583651",
         "0.00339926843076376"
        ],
        [
         "7",
         "Transolver",
         "deep_coral",
         "IWV",
         "0.17722012475132942",
         "0.0006668592749496813",
         "0.3364969938993454",
         "0.02160522001924487",
         "0.11021823063492775",
         "0.00045932606029254106",
         "0.8954248577356339",
         "0.000698552153163507",
         "0.24271417781710625",
         "0.018244980926034954",
         "0.6349333375692368",
         "0.04588438529319035",
         "4.174217224121094",
         "0.03688527695062672",
         "9.486577033996582",
         "1.0016771091627488",
         "0.023829253390431404",
         "8.768556044355869e-05",
         "0.038389863446354866",
         "0.001918215775269026",
         "243.13379669189453",
         "6.164715739614565",
         "1697.5262145996094",
         "199.2578766458922",
         "0.006503704586066306",
         "0.00015369962712400055",
         "0.008931965683586895",
         "0.0013768354454242392"
        ],
        [
         "8",
         "Transolver",
         "deep_coral",
         "SB",
         "0.17724523320794106",
         "0.0018856076020507882",
         "0.350361630320549",
         "0.0054839870251597875",
         "0.11038837395608425",
         "0.0018562688189186576",
         "0.8960041254758835",
         "0.001202484008557402",
         "0.254120297729969",
         "0.0031488028211485814",
         "0.6049044281244278",
         "0.011665209697267151",
         "4.1681026220321655",
         "0.10568931663967171",
         "10.381364822387695",
         "0.4560168141615545",
         "0.023693532682955265",
         "0.0001955321538958011",
         "0.039311052300035954",
         "0.0007129575035794555",
         "263.4605255126953",
         "26.325322379991682",
         "1812.385498046875",
         "179.9828280038763",
         "0.006571892532519996",
         "0.0004847521734904906",
         "0.011064078891649842",
         "0.002679423145098795"
        ],
        [
         "9",
         "Transolver",
         "deep_coral",
         "TB",
         "0.1776348389685154",
         "0.0006326854483082914",
         "0.3234787806868553",
         "0.01406324711885935",
         "0.10994589328765869",
         "0.0004608665837673354",
         "0.8943659663200378",
         "0.0011962493932817775",
         "0.229326993227005",
         "0.012110861013846197",
         "0.6754594296216965",
         "0.025127615643324917",
         "4.111499190330505",
         "0.039154539265250275",
         "9.09405517578125",
         "0.9773076521522676",
         "0.023952628951519728",
         "0.00011213588904806322",
         "0.037716200575232506",
         "0.0018897189183669366",
         "238.6129379272461",
         "12.179123429594158",
         "1523.0763244628906",
         "64.67971226883175",
         "0.00625076936557889",
         "0.0001165441529487587",
         "0.008356884471140802",
         "0.0019172797819504018"
        ]
       ],
       "shape": {
        "columns": 31,
        "rows": 10
       }
      },
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model_name</th>\n",
       "      <th>da_algorithm_name</th>\n",
       "      <th>model_selection_algorithm_name</th>\n",
       "      <th>test_loss_source_mean</th>\n",
       "      <th>test_loss_source_std</th>\n",
       "      <th>test_loss_target_mean</th>\n",
       "      <th>test_loss_target_std</th>\n",
       "      <th>test_loss_source_mae_mean</th>\n",
       "      <th>test_loss_source_mae_std</th>\n",
       "      <th>test_loss_source_r2_mean</th>\n",
       "      <th>...</th>\n",
       "      <th>test_loss_target_U_mean</th>\n",
       "      <th>test_loss_target_U_std</th>\n",
       "      <th>test_loss_source_p_mean</th>\n",
       "      <th>test_loss_source_p_std</th>\n",
       "      <th>test_loss_target_p_mean</th>\n",
       "      <th>test_loss_target_p_std</th>\n",
       "      <th>test_loss_source_custom_mean</th>\n",
       "      <th>test_loss_source_custom_std</th>\n",
       "      <th>test_loss_target_custom_mean</th>\n",
       "      <th>test_loss_target_custom_std</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>PointNet</td>\n",
       "      <td>-</td>\n",
       "      <td>-</td>\n",
       "      <td>0.289102</td>\n",
       "      <td>0.004506</td>\n",
       "      <td>0.483857</td>\n",
       "      <td>0.054346</td>\n",
       "      <td>0.143781</td>\n",
       "      <td>0.002413</td>\n",
       "      <td>0.850773</td>\n",
       "      <td>...</td>\n",
       "      <td>0.042930</td>\n",
       "      <td>0.003839</td>\n",
       "      <td>193.439400</td>\n",
       "      <td>8.965762</td>\n",
       "      <td>909.832581</td>\n",
       "      <td>244.474657</td>\n",
       "      <td>0.009295</td>\n",
       "      <td>0.000255</td>\n",
       "      <td>0.025060</td>\n",
       "      <td>0.005286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>PointNet</td>\n",
       "      <td>deep_coral</td>\n",
       "      <td>DEV</td>\n",
       "      <td>0.229201</td>\n",
       "      <td>0.016376</td>\n",
       "      <td>0.343492</td>\n",
       "      <td>0.014526</td>\n",
       "      <td>0.152822</td>\n",
       "      <td>0.012538</td>\n",
       "      <td>0.835649</td>\n",
       "      <td>...</td>\n",
       "      <td>0.040473</td>\n",
       "      <td>0.002814</td>\n",
       "      <td>178.709297</td>\n",
       "      <td>3.757132</td>\n",
       "      <td>635.996712</td>\n",
       "      <td>199.522610</td>\n",
       "      <td>0.010537</td>\n",
       "      <td>0.001931</td>\n",
       "      <td>0.034410</td>\n",
       "      <td>0.010235</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>PointNet</td>\n",
       "      <td>deep_coral</td>\n",
       "      <td>IWV</td>\n",
       "      <td>0.236677</td>\n",
       "      <td>0.022541</td>\n",
       "      <td>0.374063</td>\n",
       "      <td>0.033831</td>\n",
       "      <td>0.159047</td>\n",
       "      <td>0.016840</td>\n",
       "      <td>0.825653</td>\n",
       "      <td>...</td>\n",
       "      <td>0.042388</td>\n",
       "      <td>0.004847</td>\n",
       "      <td>193.481995</td>\n",
       "      <td>10.639706</td>\n",
       "      <td>953.772827</td>\n",
       "      <td>382.204503</td>\n",
       "      <td>0.011851</td>\n",
       "      <td>0.003303</td>\n",
       "      <td>0.038422</td>\n",
       "      <td>0.014105</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>PointNet</td>\n",
       "      <td>deep_coral</td>\n",
       "      <td>SB</td>\n",
       "      <td>0.218696</td>\n",
       "      <td>0.004284</td>\n",
       "      <td>0.375178</td>\n",
       "      <td>0.029989</td>\n",
       "      <td>0.145324</td>\n",
       "      <td>0.003884</td>\n",
       "      <td>0.849416</td>\n",
       "      <td>...</td>\n",
       "      <td>0.044353</td>\n",
       "      <td>0.003184</td>\n",
       "      <td>189.003838</td>\n",
       "      <td>14.340158</td>\n",
       "      <td>781.205627</td>\n",
       "      <td>69.485238</td>\n",
       "      <td>0.009397</td>\n",
       "      <td>0.000622</td>\n",
       "      <td>0.034952</td>\n",
       "      <td>0.011381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>PointNet</td>\n",
       "      <td>deep_coral</td>\n",
       "      <td>TB</td>\n",
       "      <td>0.249648</td>\n",
       "      <td>0.019236</td>\n",
       "      <td>0.336213</td>\n",
       "      <td>0.005373</td>\n",
       "      <td>0.168491</td>\n",
       "      <td>0.015101</td>\n",
       "      <td>0.809169</td>\n",
       "      <td>...</td>\n",
       "      <td>0.038369</td>\n",
       "      <td>0.000723</td>\n",
       "      <td>187.414421</td>\n",
       "      <td>14.548829</td>\n",
       "      <td>771.331429</td>\n",
       "      <td>363.399794</td>\n",
       "      <td>0.013550</td>\n",
       "      <td>0.003511</td>\n",
       "      <td>0.035423</td>\n",
       "      <td>0.010820</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Transolver</td>\n",
       "      <td>-</td>\n",
       "      <td>-</td>\n",
       "      <td>0.237241</td>\n",
       "      <td>0.002229</td>\n",
       "      <td>0.469901</td>\n",
       "      <td>0.048539</td>\n",
       "      <td>0.110502</td>\n",
       "      <td>0.001041</td>\n",
       "      <td>0.894553</td>\n",
       "      <td>...</td>\n",
       "      <td>0.040473</td>\n",
       "      <td>0.004128</td>\n",
       "      <td>256.344341</td>\n",
       "      <td>15.616404</td>\n",
       "      <td>1623.495636</td>\n",
       "      <td>210.078291</td>\n",
       "      <td>0.006504</td>\n",
       "      <td>0.000257</td>\n",
       "      <td>0.013958</td>\n",
       "      <td>0.006610</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Transolver</td>\n",
       "      <td>deep_coral</td>\n",
       "      <td>DEV</td>\n",
       "      <td>0.179159</td>\n",
       "      <td>0.002008</td>\n",
       "      <td>0.332794</td>\n",
       "      <td>0.016981</td>\n",
       "      <td>0.111716</td>\n",
       "      <td>0.002129</td>\n",
       "      <td>0.893424</td>\n",
       "      <td>...</td>\n",
       "      <td>0.038187</td>\n",
       "      <td>0.001431</td>\n",
       "      <td>237.457539</td>\n",
       "      <td>4.949937</td>\n",
       "      <td>1600.520752</td>\n",
       "      <td>197.426853</td>\n",
       "      <td>0.006569</td>\n",
       "      <td>0.000346</td>\n",
       "      <td>0.010380</td>\n",
       "      <td>0.003399</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Transolver</td>\n",
       "      <td>deep_coral</td>\n",
       "      <td>IWV</td>\n",
       "      <td>0.177220</td>\n",
       "      <td>0.000667</td>\n",
       "      <td>0.336497</td>\n",
       "      <td>0.021605</td>\n",
       "      <td>0.110218</td>\n",
       "      <td>0.000459</td>\n",
       "      <td>0.895425</td>\n",
       "      <td>...</td>\n",
       "      <td>0.038390</td>\n",
       "      <td>0.001918</td>\n",
       "      <td>243.133797</td>\n",
       "      <td>6.164716</td>\n",
       "      <td>1697.526215</td>\n",
       "      <td>199.257877</td>\n",
       "      <td>0.006504</td>\n",
       "      <td>0.000154</td>\n",
       "      <td>0.008932</td>\n",
       "      <td>0.001377</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>Transolver</td>\n",
       "      <td>deep_coral</td>\n",
       "      <td>SB</td>\n",
       "      <td>0.177245</td>\n",
       "      <td>0.001886</td>\n",
       "      <td>0.350362</td>\n",
       "      <td>0.005484</td>\n",
       "      <td>0.110388</td>\n",
       "      <td>0.001856</td>\n",
       "      <td>0.896004</td>\n",
       "      <td>...</td>\n",
       "      <td>0.039311</td>\n",
       "      <td>0.000713</td>\n",
       "      <td>263.460526</td>\n",
       "      <td>26.325322</td>\n",
       "      <td>1812.385498</td>\n",
       "      <td>179.982828</td>\n",
       "      <td>0.006572</td>\n",
       "      <td>0.000485</td>\n",
       "      <td>0.011064</td>\n",
       "      <td>0.002679</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Transolver</td>\n",
       "      <td>deep_coral</td>\n",
       "      <td>TB</td>\n",
       "      <td>0.177635</td>\n",
       "      <td>0.000633</td>\n",
       "      <td>0.323479</td>\n",
       "      <td>0.014063</td>\n",
       "      <td>0.109946</td>\n",
       "      <td>0.000461</td>\n",
       "      <td>0.894366</td>\n",
       "      <td>...</td>\n",
       "      <td>0.037716</td>\n",
       "      <td>0.001890</td>\n",
       "      <td>238.612938</td>\n",
       "      <td>12.179123</td>\n",
       "      <td>1523.076324</td>\n",
       "      <td>64.679712</td>\n",
       "      <td>0.006251</td>\n",
       "      <td>0.000117</td>\n",
       "      <td>0.008357</td>\n",
       "      <td>0.001917</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10 rows × 31 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "   model_name da_algorithm_name model_selection_algorithm_name  \\\n",
       "0    PointNet                 -                              -   \n",
       "1    PointNet        deep_coral                            DEV   \n",
       "2    PointNet        deep_coral                            IWV   \n",
       "3    PointNet        deep_coral                             SB   \n",
       "4    PointNet        deep_coral                             TB   \n",
       "5  Transolver                 -                              -   \n",
       "6  Transolver        deep_coral                            DEV   \n",
       "7  Transolver        deep_coral                            IWV   \n",
       "8  Transolver        deep_coral                             SB   \n",
       "9  Transolver        deep_coral                             TB   \n",
       "\n",
       "   test_loss_source_mean  test_loss_source_std  test_loss_target_mean  \\\n",
       "0               0.289102              0.004506               0.483857   \n",
       "1               0.229201              0.016376               0.343492   \n",
       "2               0.236677              0.022541               0.374063   \n",
       "3               0.218696              0.004284               0.375178   \n",
       "4               0.249648              0.019236               0.336213   \n",
       "5               0.237241              0.002229               0.469901   \n",
       "6               0.179159              0.002008               0.332794   \n",
       "7               0.177220              0.000667               0.336497   \n",
       "8               0.177245              0.001886               0.350362   \n",
       "9               0.177635              0.000633               0.323479   \n",
       "\n",
       "   test_loss_target_std  test_loss_source_mae_mean  test_loss_source_mae_std  \\\n",
       "0              0.054346                   0.143781                  0.002413   \n",
       "1              0.014526                   0.152822                  0.012538   \n",
       "2              0.033831                   0.159047                  0.016840   \n",
       "3              0.029989                   0.145324                  0.003884   \n",
       "4              0.005373                   0.168491                  0.015101   \n",
       "5              0.048539                   0.110502                  0.001041   \n",
       "6              0.016981                   0.111716                  0.002129   \n",
       "7              0.021605                   0.110218                  0.000459   \n",
       "8              0.005484                   0.110388                  0.001856   \n",
       "9              0.014063                   0.109946                  0.000461   \n",
       "\n",
       "   test_loss_source_r2_mean  ...  test_loss_target_U_mean  \\\n",
       "0                  0.850773  ...                 0.042930   \n",
       "1                  0.835649  ...                 0.040473   \n",
       "2                  0.825653  ...                 0.042388   \n",
       "3                  0.849416  ...                 0.044353   \n",
       "4                  0.809169  ...                 0.038369   \n",
       "5                  0.894553  ...                 0.040473   \n",
       "6                  0.893424  ...                 0.038187   \n",
       "7                  0.895425  ...                 0.038390   \n",
       "8                  0.896004  ...                 0.039311   \n",
       "9                  0.894366  ...                 0.037716   \n",
       "\n",
       "   test_loss_target_U_std  test_loss_source_p_mean  test_loss_source_p_std  \\\n",
       "0                0.003839               193.439400                8.965762   \n",
       "1                0.002814               178.709297                3.757132   \n",
       "2                0.004847               193.481995               10.639706   \n",
       "3                0.003184               189.003838               14.340158   \n",
       "4                0.000723               187.414421               14.548829   \n",
       "5                0.004128               256.344341               15.616404   \n",
       "6                0.001431               237.457539                4.949937   \n",
       "7                0.001918               243.133797                6.164716   \n",
       "8                0.000713               263.460526               26.325322   \n",
       "9                0.001890               238.612938               12.179123   \n",
       "\n",
       "   test_loss_target_p_mean  test_loss_target_p_std  \\\n",
       "0               909.832581              244.474657   \n",
       "1               635.996712              199.522610   \n",
       "2               953.772827              382.204503   \n",
       "3               781.205627               69.485238   \n",
       "4               771.331429              363.399794   \n",
       "5              1623.495636              210.078291   \n",
       "6              1600.520752              197.426853   \n",
       "7              1697.526215              199.257877   \n",
       "8              1812.385498              179.982828   \n",
       "9              1523.076324               64.679712   \n",
       "\n",
       "   test_loss_source_custom_mean  test_loss_source_custom_std  \\\n",
       "0                      0.009295                     0.000255   \n",
       "1                      0.010537                     0.001931   \n",
       "2                      0.011851                     0.003303   \n",
       "3                      0.009397                     0.000622   \n",
       "4                      0.013550                     0.003511   \n",
       "5                      0.006504                     0.000257   \n",
       "6                      0.006569                     0.000346   \n",
       "7                      0.006504                     0.000154   \n",
       "8                      0.006572                     0.000485   \n",
       "9                      0.006251                     0.000117   \n",
       "\n",
       "   test_loss_target_custom_mean  test_loss_target_custom_std  \n",
       "0                      0.025060                     0.005286  \n",
       "1                      0.034410                     0.010235  \n",
       "2                      0.038422                     0.014105  \n",
       "3                      0.034952                     0.011381  \n",
       "4                      0.035423                     0.010820  \n",
       "5                      0.013958                     0.006610  \n",
       "6                      0.010380                     0.003399  \n",
       "7                      0.008932                     0.001377  \n",
       "8                      0.011064                     0.002679  \n",
       "9                      0.008357                     0.001917  \n",
       "\n",
       "[10 rows x 31 columns]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agg_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['model_name', 'da_algorithm_name', 'model_selection_algorithm_name',\n",
       "       'test_loss_source_mean', 'test_loss_source_std',\n",
       "       'test_loss_target_mean', 'test_loss_target_std',\n",
       "       'test_loss_source_mae_mean', 'test_loss_source_mae_std',\n",
       "       'test_loss_source_r2_mean', 'test_loss_source_r2_std',\n",
       "       'test_loss_target_mae_mean', 'test_loss_target_mae_std',\n",
       "       'test_loss_target_r2_mean', 'test_loss_target_r2_std',\n",
       "       'test_loss_source_T_mean', 'test_loss_source_T_std',\n",
       "       'test_loss_target_T_mean', 'test_loss_target_T_std',\n",
       "       'test_loss_source_U_mean', 'test_loss_source_U_std',\n",
       "       'test_loss_target_U_mean', 'test_loss_target_U_std',\n",
       "       'test_loss_source_p_mean', 'test_loss_source_p_std',\n",
       "       'test_loss_target_p_mean', 'test_loss_target_p_std',\n",
       "       'test_loss_source_custom_mean', 'test_loss_source_custom_std',\n",
       "       'test_loss_target_custom_mean', 'test_loss_target_custom_std'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Metric columns follow test_loss_{source|target}[_<field>]_{mean|std}; the LaTeX helper below parses these names.\n",
    "agg_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Characters that must be escaped when raw data values are written into LaTeX.\n",
    "_LATEX_SPECIALS = {\n",
    "    \"\\\\\": r\"\\textbackslash{}\",\n",
    "    \"&\": r\"\\&\",\n",
    "    \"%\": r\"\\%\",\n",
    "    \"$\": r\"\\$\",\n",
    "    \"#\": r\"\\#\",\n",
    "    \"_\": r\"\\_\",\n",
    "    \"{\": r\"\\{\",\n",
    "    \"}\": r\"\\}\",\n",
    "    \"~\": r\"\\textasciitilde{}\",\n",
    "    \"^\": r\"\\textasciicircum{}\",\n",
    "}\n",
    "\n",
    "\n",
    "def _latex_escape(text) -> str:\n",
    "    \"\"\"Escape LaTeX special characters in plain text (never pass LaTeX markup here).\"\"\"\n",
    "    return \"\".join(_LATEX_SPECIALS.get(ch, ch) for ch in str(text))\n",
    "\n",
    "\n",
    "def _is_na(value) -> bool:\n",
    "    \"\"\"True for None / NaN / pd.NA scalars; False (never an error) for anything else.\"\"\"\n",
    "    if value is None:\n",
    "        return True\n",
    "    try:\n",
    "        return bool(pd.isna(value))\n",
    "    except (TypeError, ValueError):\n",
    "        return False\n",
    "\n",
    "\n",
    "def _same_group(a, b) -> bool:\n",
    "    \"\"\"Equality in which two missing values (None / NaN) count as equal.\"\"\"\n",
    "    if _is_na(a) or _is_na(b):\n",
    "        return _is_na(a) and _is_na(b)\n",
    "    return a == b\n",
    "\n",
    "\n",
    "def _matches_marker(value, markers) -> bool:\n",
    "    \"\"\"`value in markers`, except that any missing value (None / NaN) matches a None marker.\"\"\"\n",
    "    if _is_na(value):\n",
    "        return any(m is None for m in markers)\n",
    "    return value in markers\n",
    "\n",
    "\n",
    "def _display_text(value, name_map=None) -> str:\n",
    "    \"\"\"LaTeX-safe text for a categorical cell.\n",
    "\n",
    "    Missing values render as \"-\"; values found in `name_map` are used verbatim\n",
    "    (the map may contain LaTeX markup); anything else is LaTeX-escaped.\n",
    "    \"\"\"\n",
    "    if _is_na(value):\n",
    "        return \"-\"\n",
    "    if name_map and value in name_map:\n",
    "        return name_map[value]\n",
    "    return _latex_escape(value)\n",
    "\n",
    "\n",
    "def _shade(cell_text: str, shade_name) -> str:\n",
    "    \"\"\"Prefix `cell_text` with a cellcolor command when `shade_name` is set.\"\"\"\n",
    "    return f\"\\\\cellcolor{{{shade_name}}}{cell_text}\" if shade_name else cell_text\n",
    "\n",
    "\n",
    "def df_to_latex_table_with_highlights(\n",
    "    df: pd.DataFrame,\n",
    "    caption: str,\n",
    "    label: str,\n",
    "    float_fmt: str = \"{:.3f}\",\n",
    "    threshold_factor: float = 100.0,\n",
    "    da_name_map: dict = None,\n",
    "    field_name_map: dict = None,\n",
    "    exclude_selection: str = \"TB\",\n",
    "    # Shading controls\n",
    "    shade_best: bool = True,\n",
    "    shade_baseline: bool = True,\n",
    "    best_color_hex: str = \"DFF0D8\",      # light green\n",
    "    baseline_color_hex: str = \"FFF4CC\",  # light beige\n",
    "    baseline_da_markers = (\"-\", \"\", None),\n",
    "    baseline_sel_markers = (\"-\", \"\", None),\n",
    "    include_color_definitions: bool = True,\n",
    ") -> str:\n",
    "    \"\"\"Render aggregated mean/std results as a booktabs LaTeX table with highlights.\n",
    "\n",
    "    `df` needs `model_name`, `da_algorithm_name` and\n",
    "    `model_selection_algorithm_name` columns plus metric columns named\n",
    "    `test_loss_{source|target}[_<field>]_{mean|std}` (no `<field>` = \"all\").\n",
    "\n",
    "    Table contents:\n",
    "      • each metric cell is mean(±std), or ★ when the mean exceeds\n",
    "        `threshold_factor` × that column's median (outlier); a field without a\n",
    "        SOURCE/TARGET column, or a NaN mean, shows \"-\",\n",
    "      • one multirow block per model, cmidrule between DA algorithms,\n",
    "      • rows whose selection equals `exclude_selection` never count as best,\n",
    "      • global lowest TARGET \"all\" mean: metric, DA and selection cells are\n",
    "        bold+underlined, and so is the model name,\n",
    "      • per-model lowest TARGET \"all\" mean: metric, DA and selection underlined,\n",
    "      • per-model best row shaded `best_color_hex`; baseline rows (DA and\n",
    "        selection in the baseline markers, NaN matching a None marker) shaded\n",
    "        `baseline_color_hex`; the Model column is never shaded.\n",
    "\n",
    "    Raw data values (model names, unmapped DA / selection / field names) are\n",
    "    LaTeX-escaped; values taken from `da_name_map` / `field_name_map` are used\n",
    "    verbatim. The output uses booktabs, multirow, makecell, graphicx and\n",
    "    xcolor (loaded with the `table` option).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the TARGET \"all\" mean/std columns are missing.\n",
    "    \"\"\"\n",
    "    if da_name_map is None:\n",
    "        da_name_map = {\"deep_coral\": \"Deep Coral\", \"cmd\": \"CMD\", \"DANN\": \"DANN\"}\n",
    "    if field_name_map is None:\n",
    "        field_name_map = {\n",
    "            \"all\": \"All Fields Normalized Avg (-)\",\n",
    "            \"deformation\": \"Deformation (mm)\",\n",
    "            \"nodes_LE\": \"Logarithmic Strain (-)\",\n",
    "            \"nodes_PEEQ\": \"Equivalent Plastic Strain (-)\",\n",
    "            \"nodes_mises_stress\": \"Mises Stress (Pa)\",\n",
    "            \"nodes_stresses\": \"Stress (Pa)\",\n",
    "        }\n",
    "\n",
    "    # Parse metric columns into {field: {\"SOURCE\"/\"TARGET\": {\"mean\": col, \"std\": col}}}.\n",
    "    field_map = {}\n",
    "    pat = re.compile(r\"^test_loss_(source|target)(?:_(.+?))?_(mean|std)$\")\n",
    "    for c in df.columns:\n",
    "        m = pat.match(c)\n",
    "        if m:\n",
    "            dom, fld, stat = m.groups()\n",
    "            field_map.setdefault(fld or \"all\", {}).setdefault(dom.upper(), {})[stat] = c\n",
    "    all_target = field_map.get(\"all\", {}).get(\"TARGET\", {})\n",
    "    if \"mean\" not in all_target or \"std\" not in all_target:\n",
    "        raise ValueError(\"Missing 'all' TARGET mean/std columns.\")\n",
    "    all_tgt_mean = all_target[\"mean\"]\n",
    "\n",
    "    # Column medians are the reference for the ★ outlier marker; only parsed\n",
    "    # metric columns are used so unrelated \"*_mean\" columns cannot break median().\n",
    "    mean_cols = [\n",
    "        d[\"mean\"] for per_dom in field_map.values() for d in per_dom.values() if \"mean\" in d\n",
    "    ]\n",
    "    med = df[mean_cols].median()\n",
    "\n",
    "    # Bests (rows selected by `exclude_selection` are never candidates).\n",
    "    ok_mask = df[\"model_selection_algorithm_name\"] != exclude_selection\n",
    "    global_min = df.loc[ok_mask, all_tgt_mean].min()\n",
    "    per_model_min = df[ok_mask].groupby(\"model_name\")[all_tgt_mean].min().to_dict()\n",
    "\n",
    "    fields = sorted(field_map, key=lambda x: (x != \"all\", x))\n",
    "    nf = len(fields)\n",
    "    total_cols = 3 + 2 * nf\n",
    "\n",
    "    lines = [\n",
    "        r\"\\begin{table}[h]\",\n",
    "        r\"  \\centering\",\n",
    "        f\"  \\\\caption{{{caption}}}\",\n",
    "        f\"  \\\\label{{{label}}}\",\n",
    "    ]\n",
    "    if include_color_definitions:\n",
    "        # Emitted before \\resizebox: inside the box each line end would become a\n",
    "        # spurious space that widens (and shifts) the scaled tabular.\n",
    "        lines += [\n",
    "            r\"  % Requires \\usepackage[table]{xcolor}\",\n",
    "            fr\"  \\definecolor{{bestrow}}{{HTML}}{{{best_color_hex}}}\",\n",
    "            fr\"  \\definecolor{{baselinerow}}{{HTML}}{{{baseline_color_hex}}}\",\n",
    "        ]\n",
    "    lines += [\n",
    "        r\"  \\resizebox{\\textwidth}{!}{%\",\n",
    "        \"  \\\\begin{tabular}{lll\" + \"c\" * (2 * nf) + \"}\",\n",
    "        \"    \\\\toprule\",\n",
    "    ]\n",
    "\n",
    "    # Header: three key columns spanning two rows, then a SRC/TGT pair per field.\n",
    "    hdr1 = (\n",
    "        r\"    \\multirow{2}{*}{\\textbf{Model}}\"\n",
    "        r\" & \\multirow{2}{*}{\\makecell{\\textbf{DA}\\\\ \\textbf{Algorithm}}}\"\n",
    "        r\" & \\multirow{2}{*}{\\makecell{\\textbf{Model}\\\\ \\textbf{Selection}}}\"\n",
    "    )\n",
    "    for f in fields:\n",
    "        # Mapped labels may contain LaTeX markup; raw field keys are escaped.\n",
    "        flabel = field_name_map[f] if f in field_name_map else _latex_escape(f)\n",
    "        hdr1 += f\" & \\\\multicolumn{{2}}{{c}}{{{flabel}}}\"\n",
    "    lines.append(hdr1 + r\" \\\\\")\n",
    "    lines.append(\" \".join(f\"\\\\cmidrule(lr){{{4 + 2 * i}-{5 + 2 * i}}}\" for i in range(nf)))\n",
    "    hdr2 = \"      &   &  & \" + \" & \".join([r\"\\textbf{SRC} & \\textbf{TGT}\"] * nf) + r\" \\\\\"\n",
    "    lines.extend([hdr2, \"    \\\\midrule\"])\n",
    "\n",
    "    # Body: one multirow block per model, in order of first appearance.\n",
    "    models = list(df[\"model_name\"].unique())\n",
    "    for mi, model in enumerate(models):\n",
    "        sub = df[df[\"model_name\"] == model].sort_values(\n",
    "            [\"da_algorithm_name\", \"model_selection_algorithm_name\"]\n",
    "        )\n",
    "        sub_ok = (sub[\"model_selection_algorithm_name\"] != exclude_selection).to_numpy()\n",
    "        model_has_global = bool((sub_ok & (sub[all_tgt_mean] == global_min).to_numpy()).any())\n",
    "        model_min = per_model_min.get(model)\n",
    "\n",
    "        prev_da = None\n",
    "        for row_i, (_, row) in enumerate(sub.iterrows()):\n",
    "            da_raw = row[\"da_algorithm_name\"]\n",
    "            sel_raw = row[\"model_selection_algorithm_name\"]\n",
    "            if row_i > 0 and not _same_group(da_raw, prev_da):\n",
    "                lines.append(f\"    \\\\cmidrule(lr){{2-{total_cols}}}\")\n",
    "\n",
    "            value = row[all_tgt_mean]\n",
    "            is_candidate = bool(sub_ok[row_i])\n",
    "            is_global = is_candidate and value == global_min\n",
    "            is_model_best = is_candidate and model_min is not None and value == model_min\n",
    "            is_baseline = _matches_marker(da_raw, baseline_da_markers) and _matches_marker(\n",
    "                sel_raw, baseline_sel_markers\n",
    "            )\n",
    "\n",
    "            # Best shading wins over baseline shading.\n",
    "            if shade_best and is_model_best:\n",
    "                shade_name = \"bestrow\"\n",
    "            elif shade_baseline and is_baseline:\n",
    "                shade_name = \"baselinerow\"\n",
    "            else:\n",
    "                shade_name = None\n",
    "\n",
    "            # Model column (never shaded): multirow label on the first row only.\n",
    "            if row_i == 0:\n",
    "                model_cell = _display_text(model)\n",
    "                if model_has_global:\n",
    "                    model_cell = f\"\\\\underline{{\\\\textbf{{{model_cell}}}}}\"\n",
    "                line = f\"    \\\\multirow{{{len(sub)}}}{{*}}{{{model_cell}}} & \"\n",
    "            else:\n",
    "                line = \"    & \"\n",
    "\n",
    "            # DA & selection cells; missing values render as \"-\" instead of None/nan.\n",
    "            da_cell = _display_text(da_raw, da_name_map)\n",
    "            sel_cell = _display_text(sel_raw)\n",
    "            if is_global:\n",
    "                da_cell = f\"\\\\underline{{\\\\textbf{{{da_cell}}}}}\"\n",
    "                sel_cell = f\"\\\\underline{{\\\\textbf{{{sel_cell}}}}}\"\n",
    "            elif is_model_best:\n",
    "                da_cell = f\"\\\\underline{{{da_cell}}}\"\n",
    "                sel_cell = f\"\\\\underline{{{sel_cell}}}\"\n",
    "            line += f\"{_shade(da_cell, shade_name)} & {_shade(sel_cell, shade_name)}\"\n",
    "\n",
    "            # Metric cells: SOURCE then TARGET for every field.\n",
    "            for f in fields:\n",
    "                for dom in (\"SOURCE\", \"TARGET\"):\n",
    "                    cols = field_map[f].get(dom, {})\n",
    "                    mc, sc = cols.get(\"mean\"), cols.get(\"std\")\n",
    "                    if mc is None or _is_na(row[mc]):\n",
    "                        cell = \"-\"\n",
    "                    elif row[mc] > threshold_factor * med[mc]:\n",
    "                        cell = r\"$\\star$\"\n",
    "                    else:\n",
    "                        txt = float_fmt.format(row[mc])\n",
    "                        if sc is not None:\n",
    "                            txt += f\"(\\\\pm{float_fmt.format(row[sc])})\"\n",
    "                        if f == \"all\" and dom == \"TARGET\" and is_global:\n",
    "                            cell = f\"$\\\\underline{{\\\\mathbf{{{txt}}}}}$\"\n",
    "                        elif f == \"all\" and dom == \"TARGET\" and is_model_best:\n",
    "                            cell = f\"$\\\\underline{{{txt}}}$\"\n",
    "                        else:\n",
    "                            cell = f\"${txt}$\"\n",
    "                    line += f\" & {_shade(cell, shade_name)}\"\n",
    "\n",
    "            lines.append(line + r\" \\\\\")\n",
    "            prev_da = da_raw\n",
    "\n",
    "        if mi < len(models) - 1:\n",
    "            lines.append(\"    \\\\midrule\")\n",
    "    lines.append(\"    \\\\bottomrule\")\n",
    "\n",
    "    # `%` keeps the line end after the tabular out of the \\resizebox content.\n",
    "    lines.extend([r\"  \\end{tabular}%\", r\"  }\", r\"\\end{table}\"])\n",
    "    return \"\\n\".join(lines)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the aggregated results for the current TASK as LaTeX; print() (not the\n",
    "# cell repr) so the raw table source can be copied verbatim.\n",
    "da_display_names = {\"deep_coral\": \"Deep Coral\", \"cmd\": \"CMD\", \"DANN\": \"DANN\"}\n",
    "table = df_to_latex_table_with_highlights(\n",
    "    agg_df,\n",
    "    caption=\"Performance across different fields.\",\n",
    "    label=\"tab:fixed\",\n",
    "    da_name_map=da_display_names,\n",
    "    field_name_map=field_name_maps[TASK],\n",
    "    threshold_factor=100,\n",
    ")\n",
    "\n",
    "print(table)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "simshift",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
