{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a0b5f08-a139-4e4d-a987-47f7de3569e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import re\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from helper import print_interesting_columns, add_additional_info, get_relative_performances, save_or_show, set_ylims_with_margin\n",
    "from constants import BASE_PATH_PROJECT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63c806d4-5845-4072-8356-7e9688062822",
   "metadata": {},
   "outputs": [],
   "source": [
    "storing_path = BASE_PATH_PROJECT / \"results_iclr_exp/plots/appendix_dim_heads\"\n",
    "SAVE = False\n",
    "if SAVE:\n",
    "    storing_path.mkdir(parents=True, exist_ok=True)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae1424c4-2b08-454f-9ab5-03db3ea824d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "ablation_results = pd.read_pickle(BASE_PATH_PROJECT / 'results_ablation/aggregated/ablation_results_of_030925.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93137090-b99c-41a1-9435-8a370bc8b80c",
   "metadata": {},
   "outputs": [],
   "source": [
    "allowed_ds = ['wds/country211', \n",
    "              'wds/vtab/pets',\n",
    "              'wds/vtab/cifar100', \n",
    "              'wds/vtab/dtd']\n",
    "\n",
    "allowed_mids = [\n",
    "     '[\"dinov2-vit-base-p14_cls@norm\"]',\n",
    "     '[\"dinov2-vit-base-p14_ap@blocks.3.norm2\", \"dinov2-vit-base-p14_ap@blocks.6.norm2\", \"dinov2-vit-base-p14_ap@blocks.9.norm2\", \"dinov2-vit-base-p14_ap@norm\", \"dinov2-vit-base-p14_cls@blocks.3.norm2\", \"dinov2-vit-base-p14_cls@blocks.6.norm2\", \"dinov2-vit-base-p14_cls@blocks.9.norm2\", \"dinov2-vit-base-p14_cls@norm\"]',\n",
    "     '[\"dinov2-vit-base-p14_ap@blocks.1.norm2\", \"dinov2-vit-base-p14_ap@blocks.10.norm2\", \"dinov2-vit-base-p14_ap@blocks.11.norm2\", \"dinov2-vit-base-p14_ap@blocks.2.norm2\", \"dinov2-vit-base-p14_ap@blocks.3.norm2\", \"dinov2-vit-base-p14_ap@blocks.4.norm2\", \"dinov2-vit-base-p14_ap@blocks.5.norm2\", \"dinov2-vit-base-p14_ap@blocks.6.norm2\", \"dinov2-vit-base-p14_ap@blocks.7.norm2\", \"dinov2-vit-base-p14_ap@blocks.8.norm2\", \"dinov2-vit-base-p14_ap@blocks.9.norm2\", \"dinov2-vit-base-p14_ap@norm\", \"dinov2-vit-base-p14_cls@blocks.1.norm2\", \"dinov2-vit-base-p14_cls@blocks.10.norm2\", \"dinov2-vit-base-p14_cls@blocks.11.norm2\", \"dinov2-vit-base-p14_cls@blocks.2.norm2\", \"dinov2-vit-base-p14_cls@blocks.3.norm2\", \"dinov2-vit-base-p14_cls@blocks.4.norm2\", \"dinov2-vit-base-p14_cls@blocks.5.norm2\", \"dinov2-vit-base-p14_cls@blocks.6.norm2\", \"dinov2-vit-base-p14_cls@blocks.7.norm2\", \"dinov2-vit-base-p14_cls@blocks.8.norm2\", \"dinov2-vit-base-p14_cls@blocks.9.norm2\", \"dinov2-vit-base-p14_cls@norm\"]'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aefb7ee-8c3e-421e-bbd6-71a8f7fc2899",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the whitelisted model combinations and datasets; shapes are printed\n",
    "# before and after so the effect of the filter is visible.\n",
    "print(ablation_results.shape)\n",
    "keep_rows = ablation_results['model_ids'].isin(allowed_mids) & ablation_results['dataset'].isin(allowed_ds)\n",
    "ablation_results = ablation_results[keep_rows].copy().reset_index(drop=True)\n",
    "print(ablation_results.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "518b27dd-1163-49b4-af26-238759d41668",
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_probes = ablation_results[ablation_results['task'] == 'linear_probe'].copy().reset_index(drop=True)\n",
    "all_attn_results =  ablation_results[ablation_results['task'] == 'attentive_probe'].copy().reset_index(drop=True)\n",
    "\n",
    "linear_probes.shape, all_attn_results.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ddfa911-1dde-4d6c-879f-01b01de360ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "ref_single_model_probes = linear_probes[linear_probes['experiment'] == 'results_ref_point_branch_position_embedd'].copy().reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a9fb3e-b93e-4a54-84a3-746616a1c45f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach linear-probe-relative metrics to the attentive-probe runs, one\n",
    "# (experiment, use_scheduler) group at a time. dropna=False keeps the groups whose\n",
    "# `use_scheduler` is missing.\n",
    "# NOTE: this cell rebinds `all_attn_results`; re-run the cell that splits\n",
    "# `ablation_results` by task before re-running this one.\n",
    "new_all_attn_results = []\n",
    "for (exp_name, use_scheduler), exp_data in all_attn_results.groupby(['experiment', 'use_scheduler'], dropna=False):\n",
    "    # 'results_ablation_2' is compared against the linear probes of 'results_ablation'.\n",
    "    ref_exp = exp_name if exp_name != \"results_ablation_2\" else \"results_ablation\"\n",
    "    print(exp_name, use_scheduler, ref_exp)\n",
    "\n",
    "    ref_mask = linear_probes['experiment'] == ref_exp\n",
    "    # pd.isna (unlike np.isnan) also accepts None / pd.NA / non-numeric group keys\n",
    "    # without raising. Only match on the scheduler flag when the group has one.\n",
    "    if not pd.isna(use_scheduler):\n",
    "        ref_mask = ref_mask & (linear_probes['use_scheduler'] == use_scheduler)\n",
    "    ref_linear_probes = linear_probes[ref_mask].copy().reset_index(drop=True)\n",
    "\n",
    "    print(exp_data.shape, ref_linear_probes.shape)\n",
    "    new_all_attn_results.append(get_relative_performances(exp_data.reset_index(drop=True), ref_linear_probes))\n",
    "all_attn_results = pd.concat(new_all_attn_results).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06192b6f-98a8-49bd-b1aa-035030fbe6e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_attn_results.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f444112-385b-45f6-b2e1-92dbd1e24393",
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_probes['experiment'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1fe1359-15a8-4b30-afe4-49e0d243e0f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_attn_results = all_attn_results.drop(index=[2])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebc545c9-aee7-4c46-9e8c-1e1076776416",
   "metadata": {},
   "source": [
    "## Single model eval: cosine schedule, early stopping (no cosine schedule), input jitter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0e8d14-342c-4507-bca6-74d0c249a860",
   "metadata": {},
   "outputs": [],
   "source": [
    "linear_probes['experiment'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0bb2fe0-8315-4dee-bf41-4a5919566d28",
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp_idx = linear_probes['experiment']=='results_ref_point_branch_position_embedd'\n",
    "linear_probes.loc[tmp_idx, 'experiment'] = \"Reference results\" \n",
    "linear_probes['experiment'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac88252a-e31a-4e91-9e85-62133a3ee740",
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp_idx = linear_probes['experiment']=='results_ablation'\n",
    "linear_probes.loc[tmp_idx, 'experiment'] = \"40_epochs_cosine_scheduling\" \n",
    "linear_probes['experiment'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fba156bb-9002-4f04-872a-d470d497a55a",
   "metadata": {},
   "outputs": [],
   "source": [
    "tmp_idx = linear_probes['experiment']=='results_ablation_3'\n",
    "linear_probes.loc[tmp_idx, 'experiment'] = \"early_stopping_use_scheduler_\" + linear_probes.loc[tmp_idx, 'use_scheduler'].astype(str) \n",
    "linear_probes['experiment'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b196a64d-4549-4697-abdc-3ac2c962b52c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# linear_probes.loc[linear_probes['experiment']=='results_with_jitter', 'experiment'] = linear_probes.loc[linear_probes['experiment']=='results_with_jitter', 'experiment'] + linear_probes.loc[linear_probes['experiment']=='results_with_jitter', 'jitter_p'].astype(str) \n",
    "# linear_probes.loc[linear_probes['experiment']=='results_ref_point_branch_position_embedd', 'experiment'] = \"Reference results.\"\n",
    "# tmp_early_stopping = linear_probes['experiment']=='results_early_stopping'\n",
    "# linear_probes.loc[tmp_early_stopping, 'experiment'] = linear_probes.loc[tmp_early_stopping, 'experiment'] + \"_min_delta_abs_\" + linear_probes.loc[tmp_early_stopping, 'min_delta'].astype(str) + \"_accuracy\"\n",
    "# tmp_early_stopping = linear_probes['experiment']=='results_early_stopping_pct_delta'\n",
    "# linear_probes.loc[tmp_early_stopping, 'experiment'] = linear_probes.loc[tmp_early_stopping, 'experiment'] + \"_min_delta_pct_\" + linear_probes.loc[tmp_early_stopping, 'min_delta'].astype(str) + \"_\" + linear_probes.loc[tmp_early_stopping, 'optim_metric'].astype(str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2dfb927-df69-4f18-a9e9-fe2297641b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_columns = ['train_lp_bal_acc1', 'test_lp_bal_acc1']\n",
    "\n",
    "# One styled table per metric: experiments as rows, datasets as columns.\n",
    "for metric in metric_columns:\n",
    "    print(f\"Metric: {metric}\")\n",
    "    print(\"=\" * 50)\n",
    "\n",
    "    # Pivot to dataset x experiment, order the datasets, then flip so that each\n",
    "    # dataset becomes a column.\n",
    "    per_experiment = (\n",
    "        linear_probes\n",
    "        .pivot(index='dataset', columns='experiment', values=metric)\n",
    "        .sort_index()\n",
    "        .T\n",
    "    )\n",
    "\n",
    "    # axis=0 shades every dataset column on its own scale; values shown with 4 decimals.\n",
    "    gradient_kwargs = dict(cmap='Greens', axis=0, subset=None)\n",
    "    display(per_experiment.style.background_gradient(**gradient_kwargs).format(precision=4))\n",
    "    print(\"\\n\")  # blank space between metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee253b0-84f1-4f69-b6a5-8865f1f54480",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Fixed slots in matplotlib's 20-entry 'tab20c' palette, one per experiment label.\n",
    "# The dict order is also the bar order on the x-axis (passed as `order` below).\n",
    "color_mapping = {\n",
    "    'Reference results': 0,\n",
    "    '40_epochs_cosine_scheduling': 4,\n",
    "    'early_stopping_use_scheduler_False': 8,\n",
    "    'early_stopping_use_scheduler_True': 9,\n",
    "}\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "color_dict = {label: cmap(idx) for label, idx in color_mapping.items()}\n",
    "\n",
    "\n",
    "# One figure per metric, one facet per dataset, one bar per experiment.\n",
    "for metric in metric_columns:\n",
    "    print(f\"Metric: {metric}\")\n",
    "    g = sns.catplot(\n",
    "        linear_probes.sort_values('experiment'),\n",
    "        x = 'experiment',\n",
    "        order = list(color_mapping.keys()),\n",
    "        hue = 'experiment',\n",
    "        y = metric,\n",
    "        col = 'dataset',\n",
    "        kind='bar',\n",
    "        col_wrap=3,\n",
    "        height=3,\n",
    "        aspect = 1.25,\n",
    "        sharey=False,\n",
    "        palette=color_dict\n",
    "    )\n",
    "    g.set_titles(\"{col_name}\")\n",
    "    g.set_xlabels(\"\")\n",
    "    g.map_dataframe(set_ylims_with_margin, metric=metric,  margin_percent=0.05)\n",
    "    g.map(lambda *args, **kwargs: plt.xticks(rotation=45, ha='right'))\n",
    "\n",
    "    # `FacetGrid.figure` replaces the `.fig` alias, which seaborn documents as deprecated.\n",
    "    g.figure.tight_layout()\n",
    "    g.figure.subplots_adjust(hspace=.2)\n",
    "    path = storing_path / f\"single_model_early_stopping_only_{metric}.pdf\"\n",
    "    save_or_show(g.figure, path, SAVE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec636fe6-1333-4ca5-aafe-82c93272722a",
   "metadata": {},
   "source": [
    "**Note: Early stopping is usually on par with the reference results. Adding N(0, 0.05) jitter to the input with a probability of 0.5 sometimes improves and sometimes decreases performance.**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54236457-d430-4767-abd2-476e75bdb6b8",
   "metadata": {},
   "source": [
    "## Attention probe evaluation: early stopping (no cosine schedule), input dim\n",
    "\n",
    "We are comparing the results with the setting:\n",
    "- dinov2-base\n",
    "- dim = 2*max_dim\n",
    "- Reference setup uses cosine scheduling, no affine transform in batch norm, 40 epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3df19a61-35da-47f2-ad3b-b8f5f780d64c",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_attn_results[['experiment', 'nr_layers', 'num_heads']].value_counts().sort_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8400fe4-1a6d-4e36-9648-3abc49d04722",
   "metadata": {},
   "outputs": [],
   "source": [
    "df1 = all_attn_results[all_attn_results['experiment']=='results_ablation'].copy()\n",
    "df1['experiment'] = \"40_epochs_cosine_scheduling (input_dim=dim=attn_dim)\"\n",
    "df1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb0b405-dcbd-46db-a1ad-2b3b97c1240f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df2 = all_attn_results[all_attn_results['experiment']=='results_ablation_2'].copy()\n",
    "df2['experiment'] = \"40_epochs_cosine_scheduling (input_dim=dim, head_dim=2*(dim/num_heads))\"\n",
    "df2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdd6afaf-9735-4643-838c-54809a11d163",
   "metadata": {},
   "outputs": [],
   "source": [
    "df3 = all_attn_results[all_attn_results['experiment']=='results_ablation_3'].copy()\n",
    "df3['experiment'] = \"early_stopping_use_scheduler_\" + df3['use_scheduler'].astype(str) \n",
    "df3.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93193f04-60cf-4b28-93d4-8e0eaeb878b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_ref = all_attn_results[all_attn_results['experiment']=='results_ref_point_branch_position_embedd']\n",
    "df_ref = df_ref[df_ref['nr_layers'].isin(df1['nr_layers'].unique())].copy()\n",
    "df_ref = df_ref[df_ref['num_heads'].isin(df1['num_heads'].unique())].copy()\n",
    "df_ref['experiment'] = \"Reference results. (dim=\" + df_ref['dim'].astype(str) + \")\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "887dcc02-a36f-4755-92f8-85ad3f35abe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.concat([df1, df2, df3, df_ref]).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "793a504c-989b-4268-8f37-587193ae2baf",
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted(df['experiment'].unique())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a35d41e-9f1d-4cb2-8d28-96260f8158ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar colours: fixed slots in matplotlib's 'tab20c' palette (20 entries, stored as\n",
    "# consecutive runs of four shades per hue), so related settings share a hue.\n",
    "# The dict order is also the bar order on the x-axis (passed as `order` below).\n",
    "color_mapping = {\n",
    "    'Reference results. (dim=768.0)': 0,      # hue 1, shade 1\n",
    "    'Reference results. (dim=1536.0)': 1,     # hue 1, shade 2\n",
    "    '40_epochs_cosine_scheduling (input_dim=dim=attn_dim)': 4,                     # hue 2, shade 1\n",
    "    '40_epochs_cosine_scheduling (input_dim=dim, head_dim=2*(dim/num_heads))': 5,  # hue 2, shade 2\n",
    "    'early_stopping_use_scheduler_False': 8,  # hue 3, shade 1\n",
    "    'early_stopping_use_scheduler_True': 9,   # hue 3, shade 2\n",
    "}\n",
    "# plt.get_cmap (as in the single-model plot above) instead of plt.cm.get_cmap, which was\n",
    "# deprecated in matplotlib 3.7 and removed in 3.9.\n",
    "cmap = plt.get_cmap('tab20c')\n",
    "color_dict = {label: cmap(idx) for label, idx in color_mapping.items()}\n",
    "\n",
    "# One figure per metric, one facet per dataset, one bar per experiment label.\n",
    "for metric in ['train_lp_bal_acc1','test_lp_bal_acc1', 'relative_test_lp_bal_acc1']:\n",
    "    print(metric)\n",
    "    g = sns.catplot(\n",
    "        df,\n",
    "        x = 'experiment',\n",
    "        order = list(color_mapping.keys()),\n",
    "        y = metric,\n",
    "        col = 'dataset',\n",
    "        col_wrap=3,\n",
    "        hue = 'experiment',\n",
    "        kind='bar',\n",
    "        height=3.5,\n",
    "        aspect = 2,\n",
    "        sharey=False,\n",
    "        palette=color_dict\n",
    "    )\n",
    "    g.set_titles(\"{col_name}\")\n",
    "    g.set_xlabels(\"\")\n",
    "    g.map_dataframe(set_ylims_with_margin, metric=metric,  margin_percent=0.05)\n",
    "    g.map(lambda *args, **kwargs: plt.xticks(rotation=45, ha='right'))\n",
    "\n",
    "    def get_best_linear_probe_val(data, *args, **kwargs):\n",
    "        \"\"\"Mark the two highest linear-probe scores of the facet's dataset with horizontal lines.\n",
    "\n",
    "        Called once per facet through `g.map_dataframe`; `data` is that facet's slice of `df`.\n",
    "        Reads `metric` from the enclosing loop and `linear_probes` from the notebook namespace.\n",
    "        Does nothing for metrics whose name contains 'relative'.\n",
    "        \"\"\"\n",
    "        dataset = data[\"dataset\"].unique()[0]\n",
    "        best_linear_probe = linear_probes.loc[linear_probes[\"dataset\"]==dataset,:]\n",
    "        if 'relative' not in metric:\n",
    "            # Row labels of the two largest values of `metric` for this dataset\n",
    "            top_2_idx = best_linear_probe[metric].nlargest(2).index\n",
    "            \n",
    "            ax = plt.gca()\n",
    "            colors = [\"r\", \"orange\"]  # best / second best\n",
    "            linestyles = [\":\", \"--\"]  # best / second best\n",
    "            \n",
    "            for i, idx in enumerate(top_2_idx):\n",
    "                value = best_linear_probe.loc[idx, metric]\n",
    "                experiment = best_linear_probe.loc[idx, \"experiment\"]\n",
    "                print(f\"Top {i+1}: {metric}, {experiment}, {value}\")\n",
    "                # zorder=-1 keeps the line behind the bars; the label is attached to the\n",
    "                # line, but this cell does not draw a legend for it.\n",
    "                ax.axhline(value, c=colors[i], ls=linestyles[i], zorder=-1, \n",
    "                          label=f\"Top {i+1}: {experiment}\")\n",
    "            print()\n",
    "        \n",
    "    g.map_dataframe(get_best_linear_probe_val)\n",
    "    # `FacetGrid.figure` replaces the `.fig` alias, which seaborn documents as deprecated.\n",
    "    g.figure.tight_layout()\n",
    "    g.figure.subplots_adjust(hspace=.2)\n",
    "\n",
    "    # sns.move_legend(g, loc=\"upper left\", bbox_to_anchor=(0.7,0.4), ncol=2)\n",
    "    path = storing_path / f\"attn_probe_early_stopping_n_dim_only_{metric}.pdf\"\n",
    "    save_or_show(g.figure, path, SAVE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10e3f899-2bf8-4faa-a937-ac238c6bc48d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# color_mapping_old = {\n",
    "#     '40_epochs_cosine_scheduling (input_dim=dim=attn_dim)': 4,      # Base color 2\n",
    "#     '40_epochs_cosine_scheduling (input_dim=dim, head_dim=2*(dim/num_heads))': 5,      # Base color 3\n",
    "# }\n",
    "# color_mapping_new = {\n",
    "#     'input_dim / num_heads': 4,      # Base color 2\n",
    "#     '2 * (input_dim / num_heads)': 5,      # Base color 3\n",
    "# }\n",
    "# cmap = plt.cm.get_cmap('tab20c')\n",
    "# color_dict = {label: cmap(idx) for label, idx in color_mapping_new.items()}\n",
    "\n",
    "# exp_name_mapping = {n1:n2 for n1, n2 in zip(color_mapping_old.keys(), color_mapping_new.keys())}\n",
    "# df_subset = df[df['experiment'].isin(color_mapping_old.keys())].copy()\n",
    "# df_subset['experiment'] = df_subset['experiment'].map(exp_name_mapping)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61b12cc3-5f5d-4296-8201-a6000268e10b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# color_mapping_old = {\n",
    "#     '40_epochs_cosine_scheduling (input_dim=dim=attn_dim)': 4,      # Base color 2\n",
    "#     '40_epochs_cosine_scheduling (input_dim=dim, head_dim=2*(dim/num_heads))': 5,      # Base color 3\n",
    "# }\n",
    "# color_mapping_new = {\n",
    "#     'input_dim / num_heads': 4,      # Base color 2\n",
    "#     '2 * (input_dim / num_heads)': 5,      # Base color 3\n",
    "# }\n",
    "# cmap = plt.cm.get_cmap('tab20c')\n",
    "# color_dict = {label: cmap(idx) for label, idx in color_mapping_new.items()}\n",
    "\n",
    "# exp_name_mapping = {n1:n2 for n1, n2 in zip(color_mapping_old.keys(), color_mapping_new.keys())}\n",
    "# df_subset = df[df['experiment'].isin(color_mapping_old.keys())].copy()\n",
    "# df_subset['experiment'] = df_subset['experiment'].map(exp_name_mapping)\n",
    "\n",
    "# for metric in ['train_lp_bal_acc1','test_lp_bal_acc1', 'relative_test_lp_bal_acc1']:\n",
    "#     print(metric)\n",
    "#     g = sns.catplot(\n",
    "#         df_subset,\n",
    "#         x = 'experiment',\n",
    "#         order = list(color_mapping_new.keys()),\n",
    "#         y = metric,\n",
    "#         col = 'dataset',\n",
    "#         col_wrap=3,\n",
    "#         hue = 'experiment',\n",
    "#         kind='bar',\n",
    "#         height=3.5,\n",
    "#         aspect = 1.1,\n",
    "#         sharey=False,\n",
    "#         palette=color_dict\n",
    "#     )\n",
    "#     g.set_titles(\"{col_name}\")\n",
    "#     g.set_xlabels(\"\")\n",
    "#     # g.map_dataframe(set_ylims_with_margin, metric=metric,  margin_percent=0.1)\n",
    "#     g.map(lambda *args, **kwargs: plt.xticks(rotation=45, ha='right'))\n",
    "\n",
    "#     def get_best_linear_probe_val(data, *args, **kwargs):\n",
    "#         dataset = data[\"dataset\"].unique()[0]\n",
    "#         best_linear_probe = linear_probes.loc[linear_probes[\"dataset\"]==dataset,:]\n",
    "#         if 'relative' not in metric:\n",
    "#             # Get top 2 values\n",
    "#             top_2_idx = best_linear_probe[metric].nlargest(2).index\n",
    "            \n",
    "#             ax = plt.gca()\n",
    "#             colors = [\"r\", \"orange\"]  # Different colors for top 2\n",
    "#             linestyles = [\":\", \"--\"]  # Different line styles\n",
    "            \n",
    "#             for i, idx in enumerate(top_2_idx):\n",
    "#                 value = best_linear_probe.loc[idx, metric]\n",
    "#                 experiment = best_linear_probe.loc[idx, \"experiment\"]\n",
    "#                 print(f\"Top {i+1}: {metric}, {experiment}, {value}\")\n",
    "#                 ax.axhline(value, c=colors[i], ls=linestyles[i], zorder=-1, \n",
    "#                           label=f\"Top {i+1}: {experiment}\")\n",
    "#             print()\n",
    "#             if dataset != \"wds/country211\":\n",
    "#                 ax.set_ylim(0.85 if \"train\" in metric else 0.75, 1)\n",
    "        \n",
    "#     g.map_dataframe(get_best_linear_probe_val)\n",
    "#     g.fig.tight_layout()\n",
    "#     g.fig.subplots_adjust(hspace=.2)\n",
    "\n",
    "#     # sns.move_legend(g, loc=\"upper left\", bbox_to_anchor=(0.7,0.4), ncol=2)\n",
    "#     path = storing_path / f\"attn_probe_n_dim_only_{metric}.pdf\"\n",
    "#     save_or_show(g.fig, path, SAVE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d78fc3f9-cd15-416e-8804-2cd67c001130",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
