{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b000bdde-c9e5-4aed-a992-5adc5213b6fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd \n",
    "import seaborn as sns\n",
    "import wandb\n",
    "\n",
    "%config Completer.use_jedi = False\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86358c41-cde5-480b-a2ad-925ab667fd03",
   "metadata": {},
   "outputs": [],
   "source": [
    "pgf_with_rc_fonts = {\n",
    "    \"font.serif\": [],                   # use latex default serif font\n",
    "    \"font.sans-serif\": [\"DejaVu Sans\"], # use a specific sans-serif font\n",
    "    \"font.size\": 12,\n",
    "    \"ps.useafm\": True,\n",
    "    \"pdf.use14corefonts\": True,\n",
    "    \"text.usetex\": True,\n",
    "}\n",
    "matplotlib.rcParams.update(pgf_with_rc_fonts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "048e0a76-3e1b-445e-b4fc-8b6ad501124c",
   "metadata": {},
   "outputs": [],
   "source": [
    "api = wandb.Api()\n",
    "entity, project = \"INPUT_YOUR_ENTITY\", \"curl\"\n",
    "runs = api.runs(entity + \"/\" + project) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a863f7fa-423d-4d4b-960e-c9e332a50cac",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_best_checkpoint_and_epoch(output_model_name):\n",
    "    elements = output_model_name.split(\"_\")\n",
    "    if elements[-1] == \"model.pt\":\n",
    "        return output_model_name, \"best\"\n",
    "    else:\n",
    "        checkpoint = elements[-1].split(\".\")[0]\n",
    "        \n",
    "        elements[3] += \".pt\"\n",
    "        \n",
    "        return \"_\".join(elements[:4]), checkpoint\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07c694ef-466a-42d4-abd0-7ac9c66658f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"wiki3029\", \"cifar10\", \"cifar100\"]\n",
    "eval_records = []\n",
    "contrastive_records = []\n",
    "classifiers = (\"linear\", \"mean\")\n",
    "\n",
    "\n",
    "for run in runs: \n",
    "    if \"hydra_path\" not in run.config:\n",
    "        continue\n",
    "\n",
    "    dataset = run.config[\"dataset.name\"]\n",
    "    if dataset not in datasets:\n",
    "        continue\n",
    "        \n",
    "    if run.config[\"name\"] in classifiers:\n",
    "        if run.config[\"normalize\"]:\n",
    "            continue\n",
    "\n",
    "        extracted_keys = (\"supervised_test_acc\", \"supervised_val_acc\")\n",
    "        for k in extracted_keys:\n",
    "            run.config[k] = run.summary[k]        \n",
    "\n",
    "        run.config[\"target_weight_file\"], run.config[\"checkpoint\"] = extract_best_checkpoint_and_epoch(run.config[\"target_weight_file\"])\n",
    "        eval_records.append(run.config)\n",
    "        \n",
    "    elif run.config[\"name\"] == \"contrastive\":\n",
    "\n",
    "        v = run.config[\"hydra_path\"] + \"/\" + run.config[\"output_model_name\"]\n",
    "\n",
    "        run.config[\"target_weight_file\"] = v\n",
    "        contrastive_records.append(run.config)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec514e17-8b01-46af-9049-046defc91cb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "contrastive_df = pd.DataFrame.from_records(contrastive_records)\n",
    "eval_df = pd.DataFrame.from_records(eval_records)\n",
    "results_df = contrastive_df.merge(eval_df, how=\"inner\", on=\"target_weight_file\", suffixes=(\"\", \"_y\"))\n",
    "\n",
    "used_columns = [\"seed\", \"dataset.name\", \"supervised_test_acc\", \"supervised_val_acc\", \"dataset.num_used_classes\", \"optimizer.lr\", \"loss.neg_size\", \"epochs\", \"name_y\", \"checkpoint\"]\n",
    "results_df.drop(labels=[k for k in results_df.keys() if k not in used_columns ], inplace=True, axis=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27ca4711-b762-4a55-ab71-7df85cfd663f",
   "metadata": {},
   "outputs": [],
   "source": [
    "removed_prefix_columns = []\n",
    "for c in results_df.columns:\n",
    "    removed_prefix_columns.append(c.split(\".\")[-1].replace(\"_\", \"-\"))\n",
    "results_df.columns = removed_prefix_columns\n",
    "\n",
    "rename = {\"wiki3029\": \"Wiki-3029\", \"cifar10\": \"CIFAR-10\", \"cifar100\": \"CIFAR-100\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6b74613-276d-4591-a1de-26f7510a3240",
   "metadata": {},
   "source": [
    "## Wiki3029"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66ce0cf1-d29a-4bf6-9761-a6f3e89e5449",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"wiki3029\"\n",
    "\n",
    "for classifier in [\"mean\", \"linear\"]:\n",
    "    _results_df = results_df[results_df[\"name\"] == dataset]\n",
    "    _results_df = _results_df[_results_df[\"checkpoint\"] == \"best\"]        \n",
    "    df_per_classifer = _results_df[_results_df[\"name-y\"] == classifier]\n",
    "    idx = df_per_classifer.groupby([\"seed\", \"num-used-classes\", \"neg-size\"])[\"supervised-val-acc\"].idxmax()\n",
    "    df_per_classifer = df_per_classifer.loc[idx,]\n",
    "    mean = df_per_classifer.groupby([\"num-used-classes\", \"neg-size\"]).mean().reset_index()\n",
    "    std = df_per_classifer.groupby([\"num-used-classes\", \"neg-size\"]).std().reset_index()    \n",
    "\n",
    "    mean = mean.pivot(index='num-used-classes', columns='neg-size', values='supervised-test-acc', )\n",
    "    std = std.pivot(index='num-used-classes', columns='neg-size', values='supervised-test-acc', )    \n",
    "\n",
    "    mean = mean.sort_values(by=\"num-used-classes\", ascending=False)\n",
    "    std = std.sort_values(by=\"num-used-classes\", ascending=False)    \n",
    "\n",
    "    Ks = mean.columns\n",
    "    Cs = mean.index\n",
    "    data = mean.to_numpy()\n",
    "\n",
    "    plt.imshow(data)\n",
    "\n",
    "    plt.yticks(np.arange(len(Cs)), [r\"${}$\".format(c) for c in sorted(Cs, reverse=True)])\n",
    "    plt.xticks(np.arange(len(Ks)), [r\"${}$\".format(k) for k in sorted(list(Ks))])\n",
    "\n",
    "    for c in range(len(Cs)):\n",
    "        for k in range(len(Ks)):\n",
    "            if c <= 1:\n",
    "                color = \"white\"\n",
    "            else:\n",
    "                color = \"black\"\n",
    "            plt.text(k, c, \"${:.2f}$\".format(data[c, k]),\n",
    "                     ha=\"center\", va=\"bottom\", color=color)\n",
    "            plt.text(k, c, \"$({:.2f})$\".format(std.to_numpy()[c, k]),\n",
    "                     ha=\"center\", va=\"top\", color=color)            \n",
    "    plt.title(\"{} {} Classifier\".format(rename[dataset], classifier.capitalize()))\n",
    "    plt.xlabel(r\"$K$\")\n",
    "    plt.ylabel(r\"$C$\")    \n",
    "    plt.savefig(\"../../papers/figures/heatmap-{}-{}.pdf\".format(dataset, classifier))\n",
    "    plt.show()            "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4450658e-1603-4a34-8402-6c3f718b7246",
   "metadata": {},
   "source": [
    "## CIFAR-10/100\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab1771cf-fac4-4286-9b0d-0a5a588c8772",
   "metadata": {},
   "outputs": [],
   "source": [
    "for classifier in [\"mean\", \"linear\"]:\n",
    "    mean_dfs = []\n",
    "    std_dfs = []\n",
    "    for dataset in set(results_df.name.values) - {\"wiki3029\"}:\n",
    "        _results_df = results_df[results_df[\"name\"] == dataset]\n",
    "        _results_df = _results_df[_results_df[\"checkpoint\"] == \"best\"]        \n",
    "        df_per_classifer = _results_df[_results_df[\"name-y\"] == classifier]\n",
    "        idx = df_per_classifer.groupby([\"seed\", \"num-used-classes\", \"neg-size\"])[\"supervised-val-acc\"].idxmax()\n",
    "        df_per_classifer = df_per_classifer.loc[idx,]\n",
    "        mean = df_per_classifer.groupby([\"num-used-classes\", \"neg-size\"]).mean().reset_index()\n",
    "        std = df_per_classifer.groupby([\"num-used-classes\", \"neg-size\"]).std().reset_index()    \n",
    "\n",
    "        mean = mean.pivot(index='num-used-classes', columns='neg-size', values='supervised-test-acc', )\n",
    "        std = std.pivot(index='num-used-classes', columns='neg-size', values='supervised-test-acc', )\n",
    "\n",
    "\n",
    "        \n",
    "        mean_dfs.append(mean)\n",
    "        std_dfs.append(std)\n",
    "    mean = pd.concat(mean_dfs)\n",
    "    std = pd.concat(std_dfs)\n",
    "    \n",
    "    mean = mean.sort_values(by=\"num-used-classes\", ascending=False)\n",
    "    std = std.sort_values(by=\"num-used-classes\", ascending=False)        \n",
    "    \n",
    "    Ks = mean.columns\n",
    "    Cs = mean.index\n",
    "    data = mean.to_numpy()\n",
    "\n",
    "    plt.imshow(data)\n",
    "\n",
    "    plt.yticks(np.arange(len(Cs)), [r\"${}$\".format(c) for c in sorted(Cs, reverse=True)])\n",
    "    plt.xticks(np.arange(len(Ks)), [r\"${}$\".format(k) for k in sorted(list(Ks))])\n",
    "\n",
    "    for c in range(len(Cs)):\n",
    "        for k in range(len(Ks)):\n",
    "            if c == 0:\n",
    "                color = \"white\"\n",
    "            else:\n",
    "                color = \"black\"\n",
    "            plt.text(k, c, \"${:.2f}$\".format(data[c, k]),\n",
    "                     ha=\"center\", va=\"bottom\", color=color)\n",
    "            plt.text(k, c, \"$({:.2f})$\".format(std.to_numpy()[c, k]),\n",
    "                     ha=\"center\", va=\"top\", color=color)            \n",
    "    plt.title(\"CIFAR {} Classifier\".format(classifier.capitalize()))\n",
    "    plt.xlabel(r\"$K$\")\n",
    "    plt.ylabel(r\"$C$\")    \n",
    "    plt.savefig(\"../../papers/figures/heatmap-cifar-{}.pdf\".format(classifier))\n",
    "    plt.show()            "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "754c6b3e-8d85-41ec-bcfb-402c70c4fe71",
   "metadata": {},
   "source": [
    "### Mean classifier's test accuracy by checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f25471a2-8449-42b1-8d64-ffa6544a5f36",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../scripts/\")\n",
    "import style\n",
    "\n",
    "\n",
    "for dataset in set(results_df.name.values):\n",
    "    if dataset == \"wiki3029\":\n",
    "        continue\n",
    "        \n",
    "    _results_df = results_df[results_df[\"name\"] == dataset]\n",
    "    _results_df = _results_df[_results_df[\"checkpoint\"] != \"best\"]\n",
    "    _results_df[\"checkpoint\"] = pd.to_numeric(_results_df[\"checkpoint\"])    \n",
    "    \n",
    "    df_per_classifer = _results_df[_results_df[\"name-y\"] == \"mean\"]\n",
    "    idx = df_per_classifer.groupby([\"seed\", \"checkpoint\", \"neg-size\"])[\"supervised-val-acc\"].idxmax()\n",
    "    df_per_classifer = df_per_classifer.loc[idx,]\n",
    "    mean = df_per_classifer.groupby([\"checkpoint\", \"neg-size\"]).mean().reset_index()\n",
    "    std = df_per_classifer.groupby([\"checkpoint\", \"neg-size\"]).std().reset_index()    \n",
    "\n",
    "    mean = mean.sort_values(by=\"checkpoint\", ascending=False)\n",
    "    std = std.sort_values(by=\"checkpoint\", ascending=False)    \n",
    "\n",
    "    plot = sns.barplot(x=\"checkpoint\", y=\"supervised-test-acc\", hue=\"neg-size\", data=mean)\n",
    "    plt.ylabel(\"Mean classifier's test accuracy\")\n",
    "    plt.xlabel(\"Epochs\")\n",
    "    plt.legend(loc=\"lower left\", title=r\"$K$\")\n",
    "    plt.savefig(\"../../papers/figures/{}-mean-performance-by-epochs.pdf\".format(dataset))\n",
    "    plt.title(rename[dataset])\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d56372f-5dd3-4118-a6d0-512afad7a504",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
