{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Analysis of Similarity on Inverted Images with Varying Dataset Size and Robustness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 136,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Colours for the two input types, handed to the seaborn calls through color_kwargs.\n",
    "# Earlier candidates (\"C0\"/\"C1\" and \"magenta\"/\"cyan\") were overwritten before any use and are dropped.\n",
    "palette = {\"inverted\": \"plum\", \"standard\": \"darkturquoise\"}\n",
    "\n",
    "# hue=\"input\" colours each box by input type; the legend is switched off.\n",
    "color_kwargs = dict(hue=\"input\", palette=palette, legend=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_base_dir = \"/root/univ-data/results\"\n",
    "\n",
    "# Robustness levels that have result files; the CIFAR-10 entries stop at eps1.\n",
    "_EPS_KEYS = (\"eps0\", \"eps025\", \"eps05\", \"eps1\", \"eps3\")\n",
    "_EPS_KEYS_CIFAR = _EPS_KEYS[:4]\n",
    "\n",
    "# File-name stem on disk for each measure key used in this notebook.\n",
    "_FILE_STEMS = {\n",
    "    \"dis\": \"dis\",\n",
    "    \"jsd\": \"jsd\",\n",
    "    \"cka\": \"cka\",\n",
    "    \"2ndcos\": \"cos_sim_mean\",\n",
    "    \"jaccard\": \"jac_mean\",\n",
    "    \"proc\": \"proc_norm_mean\",\n",
    "}\n",
    "\n",
    "\n",
    "def _csv_paths(subdir, measure, suffix, eps_keys=_EPS_KEYS, overrides=None):\n",
    "    \"\"\"Map each eps key to its CSV path: base dir / subdir / stem_eps_suffix.csv.\n",
    "\n",
    "    `overrides` maps individual eps keys to a suffix that replaces `suffix`.\n",
    "    \"\"\"\n",
    "    overrides = overrides or {}\n",
    "    stem = _FILE_STEMS[measure]\n",
    "    paths = {}\n",
    "    for eps in eps_keys:\n",
    "        tag = overrides.get(eps, suffix)\n",
    "        paths[eps] = f\"{results_base_dir}/{subdir}/{stem}_{eps}_{tag}.csv\"\n",
    "    return paths\n",
    "\n",
    "\n",
    "def _measure_paths(subdir, suffixes, eps_keys=_EPS_KEYS, overrides=None):\n",
    "    \"\"\"Build the measure -> {eps -> path} mapping for one result directory.\n",
    "\n",
    "    `suffixes` maps measure key -> default file suffix (insertion order is kept);\n",
    "    `overrides` maps measure key -> {eps key: suffix} for files that deviate from it.\n",
    "    \"\"\"\n",
    "    overrides = overrides or {}\n",
    "    return {\n",
    "        measure: _csv_paths(subdir, measure, suffix, eps_keys, overrides.get(measure))\n",
    "        for measure, suffix in suffixes.items()\n",
    "    }\n",
    "\n",
    "\n",
    "# imagenet100: the eps05 files of cka / 2ndcos / proc end in _100, the other levels in _0.\n",
    "_IN100_EPS05 = {\"eps05\": 100}\n",
    "\n",
    "# dataset -> input type -> measure -> eps key -> CSV path. Key order matters: it fixes the\n",
    "# row order of the concatenated frame built from this dict.\n",
    "results = {\n",
    "    \"imagenet50\": {\n",
    "        \"standard\": _measure_paths(\"imagenet50/standard\", {\"dis\": 0, \"jsd\": 0}),\n",
    "        \"inverted\": _measure_paths(\n",
    "            \"imagenet50\",\n",
    "            {\"dis\": 0, \"jsd\": 0, \"cka\": 10, \"2ndcos\": 10, \"jaccard\": 10, \"proc\": 10},\n",
    "        ),\n",
    "    },\n",
    "    \"imagenet100\": {\n",
    "        \"standard\": _measure_paths(\"imagenet100/standard\", {\"dis\": 0, \"jsd\": 0}),\n",
    "        \"inverted\": _measure_paths(\n",
    "            \"imagenet100\",\n",
    "            # jaccard reads the jac_mean_*_10 files; jac_mean_*_100 is the alternative set of files.\n",
    "            {\"dis\": 0, \"jsd\": 0, \"cka\": 0, \"2ndcos\": 0, \"jaccard\": 10, \"proc\": 0},\n",
    "            overrides={\"cka\": _IN100_EPS05, \"2ndcos\": _IN100_EPS05, \"proc\": _IN100_EPS05},\n",
    "        ),\n",
    "    },\n",
    "    \"imagenet1k\": {\n",
    "        \"inverted\": _measure_paths(\n",
    "            \"imagenet1k\",\n",
    "            {\"dis\": 1, \"jsd\": 1, \"cka\": 0, \"2ndcos\": 0, \"jaccard\": 10, \"proc\": 0},\n",
    "        ),\n",
    "        \"standard\": _measure_paths(\"imagenet1k/standard\", {\"dis\": 0, \"jsd\": 1}),\n",
    "    },\n",
    "    \"cifar10\": {\n",
    "        \"inverted\": _measure_paths(\n",
    "            \"cifar10\",\n",
    "            {\"dis\": 0, \"jsd\": 0, \"cka\": 0, \"2ndcos\": 0, \"jaccard\": 10, \"proc\": 0},\n",
    "            eps_keys=_EPS_KEYS_CIFAR,\n",
    "        ),\n",
    "        \"standard\": _measure_paths(\"cifar10/standard\", {\"dis\": 0, \"jsd\": 0}, eps_keys=_EPS_KEYS_CIFAR),\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "sns.set_theme(\"paper\", style=\"darkgrid\", font_scale=1.5)\n",
    "\n",
    "\n",
    "def eps_to_float(eps: str) -> float:\n",
    "    \"\"\"Translate an eps key such as \"eps025\" into its numeric robustness level.\n",
    "\n",
    "    Raises ValueError for any key outside the five known levels.\n",
    "    \"\"\"\n",
    "    known_levels = (\n",
    "        (\"eps0\", 0.0),\n",
    "        (\"eps025\", 0.25),\n",
    "        (\"eps05\", 0.5),\n",
    "        (\"eps1\", 1.0),\n",
    "        (\"eps3\", 3.0),\n",
    "    )\n",
    "    for key, level in known_levels:\n",
    "        if eps == key:\n",
    "            return level\n",
    "    raise ValueError(f\"Eps level not recognized: {eps}\")\n",
    "\n",
    "# Load every result matrix listed in `results` and stack them in long format:\n",
    "# one row per ordered model pair with its score, measure, eps, dataset and input type.\n",
    "dfs = []\n",
    "for dataset, dataset_results in results.items():\n",
    "    for input_type, input_type_results in dataset_results.items():\n",
    "        for measure, path_dict in input_type_results.items():\n",
    "            for eps, csv_path in path_dict.items():\n",
    "                try:\n",
    "                    df = pd.read_csv(csv_path, index_col=0)\n",
    "                except FileNotFoundError as e:\n",
    "                    # Best effort: report the missing file and keep loading the remaining ones.\n",
    "                    print(e)\n",
    "                    continue\n",
    "\n",
    "                # Mask out comparisons between identical models, so we can drop them easier later.\n",
    "                # np.nan is used because the capitalised alias np.NaN no longer exists in NumPy 2.0.\n",
    "                for index_val in df.index:\n",
    "                    df.loc[index_val, index_val] = np.nan\n",
    "\n",
    "                # Give name to index\n",
    "                df = df.rename_axis(\"model1\")\n",
    "\n",
    "                # Bring data to long format (many rows) and remove self comparisons\n",
    "                df = df.reset_index().melt(id_vars=\"model1\", var_name=\"model2\", value_name=\"score\")\n",
    "                df = df.dropna(axis=0)\n",
    "\n",
    "                df[\"measure\"] = measure\n",
    "                df[\"eps\"] = eps_to_float(eps)\n",
    "                df[\"dataset\"] = dataset\n",
    "                df[\"input\"] = input_type\n",
    "\n",
    "                # The three rescalings below flip distance-like scores so that larger means more similar.\n",
    "                # NOTE(review): presumably proc lies in [0, 2], dis in [0, 1] and jsd in [0, ln 2] - confirm.\n",
    "                if measure == \"proc\":\n",
    "                    df.loc[:, \"score\"] = (2 - df.loc[:, \"score\"]) / 2\n",
    "\n",
    "                if measure == \"dis\":\n",
    "                    df.loc[:, \"score\"] = 1 - df.loc[:, \"score\"]\n",
    "\n",
    "                if measure == \"jsd\":\n",
    "                    df.loc[:, \"score\"] = (np.log(2) - df.loc[:, \"score\"]) / np.log(2)\n",
    "\n",
    "                dfs.append(df)\n",
    "\n",
    "data = pd.concat(dfs, axis=0)\n",
    "\n",
    "# Exclude tiny_vit_5m results on ImageNet100 and 50, because we mistakenly used checkpoints for IN1k.\n",
    "# The model name must be spelled identically for both sides of the pair, otherwise rows with\n",
    "# tiny_vit_5m as model2 stay in the data.\n",
    "data = data.loc[~(((data.model1 == \"tiny_vit_5m\") | (data.model2 == \"tiny_vit_5m\")) & (data.dataset.isin([\"imagenet100\", \"imagenet50\"])))]\n",
    "\n",
    "# Exclude broken densenet161 on IN1k\n",
    "data = data.loc[~(((data.model1 == \"densenet161\") | (data.model2 == \"densenet161\")) & (data.dataset.isin([\"imagenet1k\"])))]\n",
    "\n",
    "\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overview grid: score distribution per eps, one row per measure and one column per input type.\n",
    "# Alternative view: the same call with kind=\"strip\" and alpha=0.4.\n",
    "sns.catplot(\n",
    "    data=data,\n",
    "    x=\"eps\",\n",
    "    y=\"score\",\n",
    "    hue=\"dataset\",\n",
    "    kind=\"box\",\n",
    "    row=\"measure\",\n",
    "    col=\"input\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 139,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.catplot(data=data, x=\"dataset\", y=\"score\", hue=\"eps\", kind=\"box\", row=\"measure\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 140,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.displot(data=data, x=\"score\", hue=\"dataset\", col=\"eps\", row=\"measure\", bins=20)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Procrustes Clusters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "\n",
    "# Slice to cluster: the \"proc\" scores on imagenet1k at eps=3.\n",
    "is_in1k_eps3_proc = (data.dataset == \"imagenet1k\") & (data.eps==3.0) & (data.measure == \"proc\")\n",
    "scores_df = data.loc[is_in1k_eps3_proc, :].copy()\n",
    "\n",
    "# Cluster the scores as a single-feature column; the seed is fixed so labels are repeatable.\n",
    "scores = scores_df[\"score\"].values.reshape(-1, 1)\n",
    "clusterer = KMeans(n_clusters=4, random_state=567, n_init=\"auto\")\n",
    "kmeans = clusterer.fit(scores)\n",
    "\n",
    "scores_df[\"cluster\"] = kmeans.labels_\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scores of the clustered slice, coloured by their assigned cluster label.\n",
    "sns.catplot(\n",
    "    data=scores_df,\n",
    "    x=\"eps\",\n",
    "    y=\"score\",\n",
    "    hue=\"cluster\",\n",
    "    palette=\"pastel\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "\n",
    "# Show the members of each cluster, in ascending order of cluster label.\n",
    "cluster_ids = sorted(scores_df.cluster.unique())\n",
    "for cluster in cluster_ids:\n",
    "    members = scores_df[scores_df.cluster == cluster]\n",
    "    print(cluster)\n",
    "    display(members)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All measures for the (vgg16_bn, tiny_vit_5m) pair on imagenet1k at eps=3.\n",
    "# The expression used to be written twice; only the last value of a cell is displayed, so one copy suffices.\n",
    "data[(data.dataset==\"imagenet1k\") & (data.eps ==3) & (data.model1 == \"vgg16_bn\") & (data.model2 == \"tiny_vit_5m\")]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repsim Boxplots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### One Panel Per Metric Per Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per dataset with three side-by-side panels (CKA, ProcrustesSim, Jaccard),\n",
    "# each showing the score distribution per robustness level for inverted inputs.\n",
    "n_panels = 3\n",
    "\n",
    "# Every model that occurs on either side of a comparison.\n",
    "models = set(data.model1.unique()) | set(data.model2.unique())\n",
    "input_type = \"inverted\"\n",
    "\n",
    "# (measure key, panel title, y-axis label); only the left-most panel carries a y label.\n",
    "panels = [\n",
    "    (\"cka\", \"CKA\", \"Similarity\"),\n",
    "    (\"proc\", \"ProcrustesSim\", \"\"),\n",
    "    (\"jaccard\", \"Jaccard\", \"\"),\n",
    "]\n",
    "figure_titles = {\"imagenet1k\": \"ImageNet1k\", \"cifar10\": \"CIFAR-10\"}\n",
    "\n",
    "for dataset in [\"cifar10\", \"imagenet1k\"]:\n",
    "    fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))\n",
    "\n",
    "    for ax, (panel_measure, panel_title, panel_ylabel) in zip(axes, panels):\n",
    "        in_panel = (\n",
    "            (data.measure == panel_measure)\n",
    "            & (data.dataset == dataset)\n",
    "            & (data.input == input_type)\n",
    "            & data.model1.isin(models)\n",
    "            & data.model2.isin(models)\n",
    "        )\n",
    "        plotdata = data.loc[in_panel]\n",
    "        sns.boxplot(plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "        ax.set_title(panel_title)\n",
    "        ax.set_ylabel(panel_ylabel)\n",
    "        ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "\n",
    "    if dataset in figure_titles:\n",
    "        fig.suptitle(figure_titles[dataset], y=1.06)\n",
    "    fig.savefig(f\"../figs/repsim_inverted_{dataset}.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### One Panel Per Metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure with three panels (CKA, ProcrustesSim, Jaccard); inside each panel the datasets are\n",
    "# on the x axis and the robustness level eps is the hue.\n",
    "n_panels = 3\n",
    "\n",
    "# Every model that occurs on either side of a comparison.\n",
    "models = set(data.model1.unique()) | set(data.model2.unique())\n",
    "\n",
    "datasets = [\"cifar10\", \"imagenet1k\"]\n",
    "input_type = \"inverted\"\n",
    "\n",
    "# Left-to-right order of the x categories and their display names. The order is passed to seaborn\n",
    "# explicitly so the tick labels stay tied to their categories instead of depending on the row\n",
    "# order of `data`.\n",
    "dataset_labels = {\"imagenet1k\": \"ImageNet1k\", \"cifar10\": \"CIFAR-10\"}\n",
    "dataset_order = list(dataset_labels)\n",
    "\n",
    "# (measure key, panel title, y-axis label); only the left-most panel carries a y label.\n",
    "panels = [\n",
    "    (\"cka\", \"CKA\", \"Similarity\"),\n",
    "    (\"proc\", \"ProcrustesSim\", \"\"),\n",
    "    (\"jaccard\", \"Jaccard\", \"\"),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))\n",
    "\n",
    "for ax_idx, (panel_measure, panel_title, panel_ylabel) in enumerate(panels):\n",
    "    ax = axes[ax_idx]\n",
    "    # Only the last panel draws the eps legend, which is then moved next to the figure.\n",
    "    is_last_panel = ax_idx == len(panels) - 1\n",
    "    plotdata = data.loc[(data.measure == panel_measure) & (data.dataset.isin(datasets)) & (data.input == input_type) & data.model1.isin(models) & data.model2.isin(models)]\n",
    "    sns.boxplot(plotdata, x=\"dataset\", y=\"score\", ax=ax, hue=\"eps\", order=dataset_order, legend=is_last_panel)\n",
    "    ax.set_title(panel_title)\n",
    "    ax.set_ylabel(panel_ylabel)\n",
    "    ax.set_xlabel(r\"Dataset\")\n",
    "    ax.set_xticklabels([dataset_labels[name] for name in dataset_order])\n",
    "\n",
    "# `ax` is the last (Jaccard) panel here; park its legend to the right of the axes.\n",
    "sns.move_legend(ax, \"center\", bbox_to_anchor=(1.2,0.5))\n",
    "\n",
    "fig.savefig(\"../figs/repsim_inverted_in1k_and_cifar10.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Repsim Heatmaps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = data.copy()\n",
    "# Add back self-comparisons with a score of 1.0 so the heatmap diagonals are filled.\n",
    "# Each record follows the column order model1, model2, score, measure, eps, dataset, input.\n",
    "self_comps = []\n",
    "for dataset in df.dataset.unique():\n",
    "    for measure in df.measure.unique():\n",
    "        for eps in df.eps.unique():\n",
    "            for input_type in df.input.unique():\n",
    "                selection = df.loc[(df.dataset == dataset) & (df.measure == measure) & (df.eps == eps) & (df.input == input_type)]\n",
    "                for model in set(selection.model1.unique()) | set(selection.model2.unique()):\n",
    "                    self_comps.append((model, model, 1.0, measure, eps, dataset, input_type))\n",
    "self_comps = pd.DataFrame.from_records(self_comps, columns=df.columns)\n",
    "df = pd.concat((df, self_comps))\n",
    "print(df.tail())\n",
    "\n",
    "measures = [\"cka\", \"proc\", \"jaccard\"]\n",
    "models = [\"resnet18\", \"resnet50\", \"wide_resnet50_2\", \"wide_resnet50_4\", \"tiny_vit_5m\", \"vgg16_bn\"]\n",
    "eps_levels = sorted(df.eps.unique())\n",
    "n_measures = len(measures)\n",
    "n_eps = len(eps_levels)\n",
    "\n",
    "# Grid: one row per measure, one column per eps level present in the data. The column count is\n",
    "# derived from the data so it always matches the loop below; squeeze=False keeps `axes` 2-D\n",
    "# even when there is a single row or column.\n",
    "fig, axes = plt.subplots(n_measures, n_eps, figsize=(n_eps*3, n_measures*3), squeeze=False)\n",
    "for row_idx, measure in enumerate(measures):\n",
    "    for col_idx, eps in enumerate(eps_levels):\n",
    "        ax = axes[row_idx, col_idx]\n",
    "        plotdf = df.loc[\n",
    "            (df.dataset == \"imagenet1k\") & (df.measure == measure) & (df.eps == eps) & (df.input == \"inverted\") &\n",
    "              (df.model1.isin(models)) &\n",
    "              (df.model2.isin(models))\n",
    "        ]\n",
    "        # Back to a square model-by-model matrix for the heatmap.\n",
    "        plotdf = plotdf.pivot(index=\"model1\", columns=\"model2\", values=\"score\")\n",
    "        sns.heatmap(\n",
    "            plotdf,\n",
    "            ax=ax,\n",
    "            vmin=0,\n",
    "            vmax=1,\n",
    "            annot=True,\n",
    "            fmt=\".1f\",\n",
    "            linewidths=0.5,\n",
    "            cbar=False,\n",
    "        )\n",
    "        # The measure name labels the first column, the eps level titles the first row.\n",
    "        if col_idx == 0:\n",
    "            ax.set_ylabel(measure)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "        if row_idx == 0:\n",
    "            ax.set_title(f\"{eps=}\")\n",
    "        ax.set_xlabel(\"\")\n",
    "\n",
    "fig.savefig(\"../figs/repsim_inverted.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Agreement Plot with Min/Max"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evals = pd.read_csv(\"/root/univ-data/eval_results.csv\", index_col=0).reset_index(drop=True)\n",
    "\n",
    "# Align the dataset names with the ones used in `data` (\"imagenet\" becomes \"imagenet1k\").\n",
    "dataset_names = {\"imagenet\": \"imagenet1k\", \"imagenet100\": \"imagenet100\", \"cifar10\": \"cifar10\", \"imagenet50\": \"imagenet50\", \"cifar\": \"cifar\"}\n",
    "evals.loc[:, \"dataset\"] = evals[\"dataset\"].map(dataset_names)\n",
    "\n",
    "# Rows that agree on all of these columns are duplicates; the last occurrence is kept.\n",
    "dedup_columns = [\"model\", \"dataset\", \"acc\", \"trained_eps\", \"attack_lr\", \"attack_eps\"]\n",
    "evals = evals.drop_duplicates(dedup_columns, keep=\"last\")\n",
    "\n",
    "# Exclude tiny_vit_5m results on ImageNet100 and 50, because we mistakenly used checkpoints for IN1k\n",
    "is_wrong_checkpoint = (evals.model == \"tiny_vit_5m\") & (evals.dataset.isin([\"imagenet100\", \"imagenet50\"]))\n",
    "evals = evals.loc[~is_wrong_checkpoint]\n",
    "\n",
    "evals.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy for the rows evaluated with attack_eps == 0, per training eps;\n",
    "# one panel per dataset, one colour per model.\n",
    "sns.catplot(\n",
    "    data=evals[evals.attack_eps==0],\n",
    "    x=\"trained_eps\",\n",
    "    y=\"acc\",\n",
    "    hue=\"model\",\n",
    "    col=\"dataset\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Add column with accuracy difference to correlate with similarity metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach |acc(model1) - acc(model2)| to every comparison row of `data`.\n",
    "# The accuracies are looked up in `evals` for the row's dataset, using the entries\n",
    "# whose attack_eps and trained_eps both equal the row's eps.\n",
    "data2 = data.reset_index(drop=True).copy()\n",
    "data2[\"acc_diff\"] = np.nan\n",
    "\n",
    "for index, row in data.reset_index(drop=True).iterrows():\n",
    "    dataset = row[\"dataset\"]\n",
    "\n",
    "    # Part of the lookup that is shared by both models of the pair.\n",
    "    same_setting = (\n",
    "        (evals.dataset == dataset)\n",
    "        & (evals.attack_eps == row[\"eps\"])\n",
    "        & (evals.trained_eps == row[\"eps\"])\n",
    "    )\n",
    "    acc1 = evals.loc[(evals.model == row[\"model1\"]) & same_setting, \"acc\"]\n",
    "    acc2 = evals.loc[(evals.model == row[\"model2\"]) & same_setting, \"acc\"]\n",
    "\n",
    "    # Leave NaN unless each model has exactly one matching accuracy entry.\n",
    "    if len(acc1) != 1 or len(acc2) != 1:\n",
    "        continue\n",
    "    data2.loc[index, \"acc_diff\"] = abs(acc1.item() - acc2.item())\n",
    "\n",
    "data2.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sns.lmplot(data2[data2.measure == \"dis\"], x=\"acc_diff\", y=\"score\", row=\"dataset\", col=\"eps\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Agreement\n",
    "#### Standard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def disagreement_lower_bound(x, num_classes: int=1000):\n",
    "    \"\"\"Lower bound on the disagreement rate of two models, given only their accuracies.\n",
    "\n",
    "    Fort et al 2020. Deep Ensembles: A Loss Landscape Perspective. Apx D2\n",
    "\n",
    "    Args:\n",
    "        x: Row-like object indexable by \"acc1\" and \"acc2\"; both values are divided\n",
    "            by 100 here, i.e. they are treated as percentages.\n",
    "        num_classes: Number of classes of the classification task.\n",
    "\n",
    "    Returns:\n",
    "        The bound as a fraction; np.nan if either accuracy is NaN; 0.0 if both\n",
    "        accuracies are equal.\n",
    "    \"\"\"\n",
    "    a = x[\"acc1\"]/100\n",
    "    a_star = x[\"acc2\"]/100\n",
    "\n",
    "    if np.isnan(a) or np.isnan(a_star):\n",
    "        return np.nan\n",
    "\n",
    "    # The formula below expects a_star to be the larger of the two accuracies.\n",
    "    if a > a_star:\n",
    "        a_star, a = a, a_star\n",
    "\n",
    "    # Equal accuracies make the numerator zero, so the bound is zero. Returning early\n",
    "    # also avoids 0/0 when both accuracies sit exactly at 1/num_classes.\n",
    "    if a_star == a:\n",
    "        return 0.0\n",
    "\n",
    "    return (num_classes - 1) * (a_star - a) / (num_classes * a_star - 1)\n",
    "\n",
    "\n",
    "def disagreement_upper_bound(x, num_classes: int=1000):\n",
    "    \"\"\"Upper bound on the disagreement rate of two models, given only their accuracies.\n",
    "\n",
    "    Fort et al 2020. Deep Ensembles: A Loss Landscape Perspective. Apx D1\n",
    "\n",
    "    Args:\n",
    "        x: Row-like object indexable by \"acc1\" and \"acc2\"; both values are divided\n",
    "            by 100 here, i.e. they are treated as percentages.\n",
    "        num_classes: Number of classes of the classification task.\n",
    "\n",
    "    Returns:\n",
    "        The bound as a fraction, or np.nan if either accuracy is NaN.\n",
    "    \"\"\"\n",
    "    acc_first, acc_second = x[\"acc1\"]/100, x[\"acc2\"]/100\n",
    "    if np.isnan(acc_first) or np.isnan(acc_second):\n",
    "        return np.nan\n",
    "\n",
    "    # Order the pair so that `a` is the smaller and `a_star` the larger accuracy.\n",
    "    if acc_first > acc_second:\n",
    "        a, a_star = acc_second, acc_first\n",
    "    else:\n",
    "        a, a_star = acc_first, acc_second\n",
    "\n",
    "    first_term = (1 - a_star) * a\n",
    "    second_term = (1 - a) * a_star\n",
    "    third_term = (1 - a_star) * (1 - a) * (num_classes - 2) / (num_classes - 1)\n",
    "    return first_term + second_term + third_term\n",
    "\n",
    "\n",
    "def combine_agreement_with_acc(data: pd.DataFrame, evals: pd.DataFrame, dataset: str, num_classes: int = 1000, input_type=\"standard\", combine_with_regular_acc: bool=True):\n",
    "    \"\"\"Join agreement scores with the accuracies of both compared models and add bounds.\n",
    "\n",
    "    Args:\n",
    "        data: Similarity table; rows with the given dataset, measure \"dis\" and the\n",
    "            given input type are kept.\n",
    "        evals: Evaluation table providing \"acc\" per model and trained_eps.\n",
    "        dataset: Dataset name used to filter both tables.\n",
    "        num_classes: Forwarded to the disagreement bound functions.\n",
    "        input_type: Value of the \"input\" column to keep.\n",
    "        combine_with_regular_acc: If True use the evals rows with attack_eps == 0,\n",
    "            otherwise the rows where attack_eps equals trained_eps.\n",
    "\n",
    "    Returns:\n",
    "        The filtered comparisons with columns \"acc1\", \"acc2\", \"ub\" and \"lb\" added,\n",
    "        where \"ub\"/\"lb\" are one minus the lower/upper disagreement bound.\n",
    "    \"\"\"\n",
    "    is_selected = (data.dataset == dataset) & (data.measure == \"dis\") & (data.input == input_type)\n",
    "    agreement_data = data.loc[is_selected, :].copy()\n",
    "\n",
    "    if combine_with_regular_acc:\n",
    "        attack_filter = evals.attack_eps==0\n",
    "    else:\n",
    "        attack_filter = evals.attack_eps==evals.trained_eps\n",
    "    evals_subset = evals.loc[(evals.dataset == dataset) & attack_filter, :]\n",
    "    evals_subset = evals_subset.drop_duplicates(\n",
    "        [\"model\", \"dataset\", \"acc\", \"trained_eps\", \"attack_lr\", \"attack_eps\"], keep=\"last\"\n",
    "    )\n",
    "    print(len(agreement_data), len(evals_subset))\n",
    "\n",
    "    # Columns that come along from `evals` but are not needed after the join.\n",
    "    eval_only_columns = [\"loss\", \"attack_lr\", \"attack_eps\", \"timestamp\", \"dataset_y\"]\n",
    "\n",
    "    # Attach the accuracy of model1 as \"acc1\", then the accuracy of model2 as \"acc2\".\n",
    "    merged = agreement_data\n",
    "    for model_col, acc_col in ((\"model1\", \"acc1\"), (\"model2\", \"acc2\")):\n",
    "        merged = pd.merge(\n",
    "            merged,\n",
    "            evals_subset.rename(columns={\"model\": model_col, \"trained_eps\": \"eps\"}),\n",
    "            how=\"left\",\n",
    "            on=[model_col, \"eps\"],\n",
    "            suffixes=(None, \"_y\"),\n",
    "        )\n",
    "        merged = merged.drop(columns=eval_only_columns)\n",
    "        merged = merged.rename(columns={\"acc\": acc_col})\n",
    "\n",
    "    # Agreement bounds are the complements of the disagreement bounds.\n",
    "    merged[\"ub\"] = 1 - merged.apply(disagreement_lower_bound, axis=1, num_classes=num_classes)\n",
    "    merged[\"lb\"] = 1 - merged.apply(disagreement_upper_bound, axis=1, num_classes=num_classes)\n",
    "    return merged\n",
    "\n",
    "\n",
    "\n",
    "# Figure: agreement on standard inputs vs. robustness level, with the average\n",
    "# accuracy-derived upper and lower bounds drawn on top of the boxes.\n",
    "n_panels = 2\n",
    "fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))\n",
    "\n",
    "fontsize=24\n",
    "# fontweight=\"normal\"\n",
    "fontweight=\"bold\"\n",
    "\n",
    "ax_idx = 0\n",
    "ax = axes[ax_idx]\n",
    "plotdata = combine_agreement_with_acc(data2, evals, \"imagenet1k\")\n",
    "sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "ax.set_title(\"ImageNet1k\")\n",
    "ax.set_ylabel(\"Agreement\")\n",
    "ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "bounds = plotdata.groupby([\"eps\"])[[\"ub\", \"lb\"]].mean()\n",
    "# The boxes are drawn at integer positions 0..k-1, one per eps group, so take the\n",
    "# x-coordinates from the number of groups instead of hard-coding them.\n",
    "bound_positions = range(len(bounds))\n",
    "ax.plot(bound_positions, bounds[\"lb\"], color=\"gray\", linestyle=\"--\")\n",
    "ax.plot(bound_positions, bounds[\"ub\"], color=\"gray\", linestyle=\"-.\")\n",
    "ax.text(3.4, 0.94, \"Avg. Upper Bound\", color=\"gray\", ha=\"center\")\n",
    "ax.text(3.4, 0.51, \"Avg. Lower Bound\", color=\"gray\", ha=\"center\")\n",
    "# Panel letter, placed in axes coordinates just above the top-left corner.\n",
    "ax.text(x=-0.15, y=1.03, s=\"A\", va=\"bottom\", color=\"black\", transform=ax.transAxes, fontdict={\n",
    "                        'fontsize': fontsize,\n",
    "                        'fontweight': fontweight\n",
    "                    })\n",
    "ax_idx += 1\n",
    "\n",
    "# ax = axes[ax_idx]\n",
    "# plotdata = combine_agreement_with_acc(data2, evals, \"imagenet100\", num_classes=100)\n",
    "# sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "# ax.set_title(\"ImageNet100\")\n",
    "# ax.set_ylabel(\"Agreement\")\n",
    "# ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "# bounds = plotdata.groupby([\"eps\"])[[\"ub\", \"lb\"]].mean()\n",
    "# ax.plot(range(len(bounds)), bounds[\"lb\"], color=\"gray\", linestyle=\"--\")\n",
    "# ax.plot(range(len(bounds)), bounds[\"ub\"], color=\"gray\", linestyle=\"-.\")\n",
    "# ax_idx += 1\n",
    "\n",
    "ax = axes[ax_idx]\n",
    "plotdata = combine_agreement_with_acc(data2, evals, \"cifar10\", num_classes=10)\n",
    "sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "ax.set_title(\"CIFAR10\")\n",
    "ax.set_ylabel(\"Agreement\")\n",
    "ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "bounds = plotdata.groupby([\"eps\"])[[\"ub\", \"lb\"]].mean()\n",
    "bound_positions = range(len(bounds))\n",
    "ax.plot(bound_positions, bounds[\"lb\"], color=\"gray\", linestyle=\"--\")\n",
    "ax.plot(bound_positions, bounds[\"ub\"], color=\"gray\", linestyle=\"-.\")\n",
    "ax.text(2.8, 0.89, \"Avg. Upper Bound\", color=\"gray\", ha=\"center\")\n",
    "ax.text(2.5, 0.47, \"Avg. Lower Bound\", color=\"gray\", ha=\"right\")\n",
    "\n",
    "\n",
    "fig.savefig(\"../figs/agreement_regular.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inverted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: agreement on inverted inputs vs. robustness level, one panel per dataset.\n",
    "n_panels = 2\n",
    "fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))\n",
    "\n",
    "fontsize=24\n",
    "fontweight=\"bold\"\n",
    "\n",
    "# (dataset key, panel title, number of classes), left to right. An ImageNet100 panel\n",
    "# would be (\"imagenet100\", \"ImageNet100\", 100) together with n_panels = 3.\n",
    "panel_specs = [\n",
    "    (\"imagenet1k\", \"ImageNet1k\", 1000),\n",
    "    (\"cifar10\", \"CIFAR10\", 10),\n",
    "]\n",
    "\n",
    "for ax_idx, (dataset_key, panel_title, n_classes) in enumerate(panel_specs):\n",
    "    ax = axes[ax_idx]\n",
    "    plotdata = combine_agreement_with_acc(\n",
    "        data2,\n",
    "        evals,\n",
    "        dataset_key,\n",
    "        num_classes=n_classes,\n",
    "        input_type=\"inverted\",\n",
    "        combine_with_regular_acc=False,\n",
    "    )\n",
    "    sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "    ax.set_title(panel_title)\n",
    "    ax.set_ylabel(\"Agreement\")\n",
    "    ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "\n",
    "    # Only the first panel carries the figure letter.\n",
    "    if ax_idx == 0:\n",
    "        ax.text(\n",
    "            x=-0.15,\n",
    "            y=1.03,\n",
    "            s=\"A\",\n",
    "            va=\"bottom\",\n",
    "            color=\"black\",\n",
    "            transform=ax.transAxes,\n",
    "            fontdict={'fontsize': fontsize, 'fontweight': fontweight},\n",
    "        )\n",
    "\n",
    "fig.savefig(\"../figs/agreement_inverted.pdf\", bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### JSD\n",
    "\n",
    "#### Standard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: JSD-based similarity on standard inputs vs. robustness level.\n",
    "n_panels = 2\n",
    "fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))\n",
    "\n",
    "# Panel-letter styling, defined in this cell so it does not depend on variables\n",
    "# left behind by a previously executed plotting cell.\n",
    "fontsize=24\n",
    "fontweight=\"bold\"\n",
    "\n",
    "ax_idx = 0\n",
    "ax = axes[ax_idx]\n",
    "plotdata = data2.loc[(data2.dataset == \"imagenet1k\") & (data2.measure == \"jsd\") & (data2.input == \"standard\")]\n",
    "sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "ax.set_title(\"ImageNet1k\")\n",
    "ax.set_ylabel(\"JSDSim\")\n",
    "ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "# Panel letter, placed in axes coordinates just above the top-left corner.\n",
    "ax.text(x=-0.15, y=1.03, s=\"B\", va=\"bottom\", color=\"black\", transform=ax.transAxes, fontdict={\n",
    "                        'fontsize': fontsize,\n",
    "                        'fontweight': fontweight\n",
    "                    })\n",
    "\n",
    "ax_idx += 1\n",
    "\n",
    "ax = axes[ax_idx]\n",
    "plotdata = data2.loc[(data2.dataset == \"cifar10\") & (data2.measure == \"jsd\") & (data2.input == \"standard\")]\n",
    "sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "ax.set_title(\"CIFAR10\")\n",
    "ax.set_ylabel(\"\")\n",
    "ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "ax_idx += 1\n",
    "\n",
    "# plotdata = data2.loc[(data2.dataset == \"imagenet100\") & (data2.measure == \"jsd\") & (data2.input == \"standard\")]\n",
    "# sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax)\n",
    "# ax.set_title(\"ImageNet100\")\n",
    "# ax.set_ylabel(\"\")\n",
    "# ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "# ax_idx += 1\n",
    "\n",
    "fig.savefig(\"../figs/jsd_regular.pdf\", bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inverted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: JSD-based similarity on inverted inputs vs. robustness level.\n",
    "n_panels = 2\n",
    "fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))\n",
    "\n",
    "# Panel-letter styling, defined in this cell so it does not depend on variables\n",
    "# left behind by a previously executed plotting cell.\n",
    "fontsize=24\n",
    "fontweight=\"bold\"\n",
    "\n",
    "ax_idx = 0\n",
    "ax = axes[ax_idx]\n",
    "plotdata = data2.loc[(data2.dataset == \"imagenet1k\") & (data2.measure == \"jsd\") & (data2.input == \"inverted\")]\n",
    "sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "ax.set_title(\"ImageNet1k\")\n",
    "ax.set_ylabel(\"JSDSim\")\n",
    "ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "# Panel letter, placed in axes coordinates just above the top-left corner.\n",
    "ax.text(x=-0.15, y=1.03, s=\"B\", va=\"bottom\", color=\"black\", transform=ax.transAxes, fontdict={\n",
    "                        'fontsize': fontsize,\n",
    "                        'fontweight': fontweight\n",
    "                    })\n",
    "\n",
    "ax_idx += 1\n",
    "\n",
    "ax = axes[ax_idx]\n",
    "plotdata = data2.loc[(data2.dataset == \"cifar10\") & (data2.measure == \"jsd\") & (data2.input == \"inverted\")]\n",
    "sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax, **color_kwargs)\n",
    "ax.set_title(\"CIFAR10\")\n",
    "ax.set_ylabel(\"\")\n",
    "ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "ax_idx += 1\n",
    "\n",
    "# plotdata = data2.loc[(data2.dataset == \"imagenet100\") & (data2.measure == \"jsd\") & (data2.input == \"inverted\")]\n",
    "# sns.boxplot(data=plotdata, x=\"eps\", y=\"score\", ax=ax)\n",
    "# ax.set_title(\"ImageNet100\")\n",
    "# ax.set_ylabel(\"\")\n",
    "# ax.set_xlabel(r\"Robustness $\\epsilon$\")\n",
    "# ax_idx += 1\n",
    "\n",
    "fig.savefig(\"../figs/jsd_inverted.pdf\", bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Size Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Overview: similarity scores on inverted inputs across the ImageNet subsets,\n",
    "# one panel per measure, boxes coloured by eps.\n",
    "datasets = [\"imagenet50\", \"imagenet100\", \"imagenet1k\"]\n",
    "input_type = \"inverted\"\n",
    "\n",
    "in_scope = (data.input==input_type) & data.dataset.isin(datasets)\n",
    "plotdata = data.loc[in_scope].copy()\n",
    "\n",
    "# Derive the class count first, because the dataset column is relabelled right after.\n",
    "class_counts = {\"imagenet1k\": 1000, \"imagenet100\": 100, \"imagenet50\": 50}\n",
    "short_labels = {\"imagenet1k\": \"IN1k\", \"imagenet100\": \"IN100\", \"imagenet50\": \"IN50\"}\n",
    "plotdata[\"num_classes\"] = plotdata[\"dataset\"].map(class_counts)\n",
    "plotdata[\"dataset\"] = plotdata[\"dataset\"].map(short_labels)\n",
    "\n",
    "# A layout with eps on the x-axis and the dataset as hue was tried before this one.\n",
    "sns.catplot(\n",
    "    plotdata,\n",
    "    hue=\"eps\",\n",
    "    y=\"score\",\n",
    "    col=\"measure\",\n",
    "    x=\"dataset\",\n",
    "    kind=\"box\",\n",
    "    sharey=False,\n",
    "    palette=\"flare\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figures: effect of dataset size on similarity for inverted inputs.\n",
    "# The first figure shows the representational measures, the second the\n",
    "# prediction-based ones.\n",
    "datasets = [\"imagenet50\", \"imagenet100\", \"imagenet1k\"]\n",
    "input_type = \"inverted\"\n",
    "\n",
    "plotdata = data.loc[(data.input==input_type) & data.dataset.isin(datasets)].copy()\n",
    "plotdata[\"num_classes\"] = plotdata[\"dataset\"].map({\"imagenet1k\": 1000, \"imagenet100\": 100, \"imagenet50\": 50})\n",
    "plotdata[\"dataset\"] = plotdata[\"dataset\"].map({\"imagenet1k\": \"IN1k\", \"imagenet100\": \"IN100\", \"imagenet50\": \"IN50\"})\n",
    "\n",
    "fontsize=24\n",
    "fontweight=\"bold\"\n",
    "\n",
    "# One entry per figure: (output path, [(measure key, panel title), ...]).\n",
    "figure_specs = [\n",
    "    (\n",
    "        \"../figs/imagenets_size_repsim_inverted.pdf\",\n",
    "        [(\"cka\", \"CKA\"), (\"proc\", \"ProcrustesSim\"), (\"jaccard\", \"Jaccard\")],\n",
    "    ),\n",
    "    (\n",
    "        \"../figs/imagenets_size_preds_inverted.pdf\",\n",
    "        [(\"dis\", \"Agreement\"), (\"jsd\", \"JSDSim\")],\n",
    "    ),\n",
    "]\n",
    "\n",
    "for out_path, panel_specs in figure_specs:\n",
    "    n_panels = len(panel_specs)\n",
    "    fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))\n",
    "\n",
    "    ax_idx = 0\n",
    "    for measure_key, panel_title in panel_specs:\n",
    "        is_last_panel = ax_idx == n_panels - 1\n",
    "        subplotdata = plotdata[plotdata.measure == measure_key]\n",
    "        ax = axes[ax_idx]\n",
    "        # Only the right-most panel draws the eps legend.\n",
    "        sns.boxplot(data=subplotdata, x=\"dataset\", y=\"score\", hue=\"eps\", ax=ax, legend=is_last_panel)\n",
    "        ax.set_title(panel_title)\n",
    "        # The shared y-label is shown on the left-most panel only.\n",
    "        ax.set_ylabel(\"Similarity\" if ax_idx == 0 else \"\")\n",
    "        ax.set_xlabel(\"Datasets\")\n",
    "        if is_last_panel:\n",
    "            sns.move_legend(ax, loc=\"right\", bbox_to_anchor=(1.4, 0.5))\n",
    "        ax_idx += 1\n",
    "\n",
    "    fig.savefig(out_path, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figure: effect of dataset size on the prediction-based measures for standard inputs.\n",
    "datasets = [\"imagenet50\", \"imagenet100\", \"imagenet1k\"]\n",
    "input_type = \"standard\"\n",
    "\n",
    "plotdata = data.loc[(data.input==input_type) & data.dataset.isin(datasets)].copy()\n",
    "plotdata[\"num_classes\"] = plotdata[\"dataset\"].map({\"imagenet1k\": 1000, \"imagenet100\": 100, \"imagenet50\": 50})\n",
    "plotdata[\"dataset\"] = plotdata[\"dataset\"].map({\"imagenet1k\": \"IN1k\", \"imagenet100\": \"IN100\", \"imagenet50\": \"IN50\"})\n",
    "\n",
    "n_panels = 2\n",
    "fig, axes = plt.subplots(1, n_panels, figsize=(n_panels*3*1.61, 3))\n",
    "\n",
    "fontsize=24\n",
    "fontweight=\"bold\"\n",
    "\n",
    "# (measure key, panel title), left to right.\n",
    "panel_specs = [(\"dis\", \"Agreement\"), (\"jsd\", \"JSDSim\")]\n",
    "\n",
    "ax_idx = 0\n",
    "for measure_key, panel_title in panel_specs:\n",
    "    subplotdata = plotdata[plotdata.measure == measure_key]\n",
    "    ax = axes[ax_idx]\n",
    "    # Only the right-most panel draws the eps legend.\n",
    "    sns.boxplot(\n",
    "        data=subplotdata,\n",
    "        x=\"dataset\",\n",
    "        y=\"score\",\n",
    "        hue=\"eps\",\n",
    "        ax=ax,\n",
    "        legend=(ax_idx == n_panels - 1),\n",
    "        palette=\"crest\",\n",
    "    )\n",
    "    ax.set_title(panel_title)\n",
    "    # The shared y-label is shown on the left-most panel only.\n",
    "    ax.set_ylabel(\"Similarity\" if ax_idx == 0 else \"\")\n",
    "    ax.set_xlabel(\"Datasets\")\n",
    "    ax_idx += 1\n",
    "\n",
    "# `ax` is the last panel here; move its legend outside the plotting area.\n",
    "sns.move_legend(ax, loc=\"right\", bbox_to_anchor=(1.4, 0.5))\n",
    "\n",
    "fig.savefig(\"../figs/imagenets_size_preds_regular.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "robust-transfer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
