{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "248ecc27",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd \n",
    "import wandb\n",
    "import sys\n",
    "\n",
    "%config Completer.use_jedi = False\n",
    "%matplotlib inline\n",
    "\n",
    "sys.path.append(\"../../scripts/\")\n",
    "import style"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c75116a",
   "metadata": {},
   "outputs": [],
   "source": [
    "api = wandb.Api()\n",
    "entity, project = \"INPUT_YOUR_ENTITY\", \"curl\"\n",
    "runs = api.runs(entity + \"/\" + project) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be52edd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = [\"wiki3029\", \"cifar10\", \"cifar100\"]\n",
    "classifier = \"mean\"\n",
    "\n",
    "eval_records = []\n",
    "contrastive_records = []\n",
    "contrastive_eval_records = []\n",
    "\n",
    "for run in runs: \n",
    "    \n",
    "    if \"hydra_path\" not in run.config:\n",
    "        continue\n",
    "\n",
    "    if \"k36049\" not in run.config[\"hydra_path\"]:\n",
    "        continue\n",
    "        \n",
    "    dataset = run.config[\"dataset.name\"]\n",
    "    if dataset not in datasets:\n",
    "        continue\n",
    "        \n",
    "    if run.config[\"name\"] == classifier:\n",
    "        \n",
    "        if not run.config[\"normalize\"]:\n",
    "            continue\n",
    "\n",
    "        extracted_keys = (\"supervised_test_acc\", \"supervised_val_acc\", \"supervised_test_loss\")\n",
    "        \n",
    "        for k in extracted_keys:\n",
    "            run.config[k] = run.summary[k]        \n",
    "\n",
    "        eval_records.append(run.config)\n",
    "        \n",
    "    elif run.config[\"name\"] == \"contrastive\":\n",
    "\n",
    "        v = run.config[\"hydra_path\"] + \"/\" + run.config[\"output_model_name\"]\n",
    "        run.config[\"target_weight_file\"] = v\n",
    "        contrastive_records.append(run.config)\n",
    "\n",
    "    elif run.config[\"name\"] == \"contrastive_eval\":\n",
    "        \n",
    "        extracted_keys = (\"contrastive_val_loss\", \"contrastive_test_loss\")\n",
    "        \n",
    "        for k in extracted_keys:\n",
    "            run.config[k] = run.summary[k]\n",
    "\n",
    "        contrastive_eval_records.append(run.config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6fa8b3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "contrastive_df = pd.DataFrame.from_records(contrastive_records)\n",
    "eval_df = pd.DataFrame.from_records(eval_records)\n",
    "contrastive_eval_df = pd.DataFrame.from_records(contrastive_eval_records)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f2a0222",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df = contrastive_df.merge(eval_df, how=\"inner\", on=\"target_weight_file\", suffixes=(\"\", \"_y\"))\n",
    "\n",
    "used_columns = [\n",
    "    \"seed\", \"dataset.name\", \"dataset.num_used_classes\", \"optimizer.lr\", \"loss.neg_size\",\n",
    "    \"supervised_test_loss\", \"supervised_val_loss\", \"supervised_val_acc\", \"target_weight_file\"\n",
    "]\n",
    "merged_df.drop(labels=[k for k in merged_df.keys() if k not in used_columns ], inplace=True, axis=1)\n",
    "merged_df.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19439cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df = merged_df.merge(contrastive_eval_df, how=\"inner\", on=\"target_weight_file\", suffixes=(\"\", \"_y\"))\n",
    "\n",
    "used_columns = used_columns[:-1] + [\"contrastive_test_loss\", \"supervised_test_loss\"]\n",
    "merged_df.drop(labels=[k for k in merged_df.keys() if k not in used_columns ], inplace=True, axis=1)\n",
    "merged_df.head(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ecf00d",
   "metadata": {},
   "outputs": [],
   "source": [
    "removed_prefix_columns = []\n",
    "\n",
    "# latex is unhappy with `_`, so replace them.\n",
    "for c in merged_df.columns:\n",
    "    removed_prefix_columns.append(c.split(\".\")[-1].replace(\"_\", \"-\"))\n",
    "\n",
    "merged_df.columns = removed_prefix_columns\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05502106",
   "metadata": {},
   "outputs": [],
   "source": [
    "from plot_k_msuploss import ub_intercept\n",
    "from compare_upper_bound import collision_term\n",
    "\n",
    "\n",
    "def tau(k, c):\n",
    "    \"\"\"\n",
    "    collision probability.\n",
    "    \"\"\"\n",
    "    return (1. - (1. - 1. / c) ** k)\n",
    "\n",
    "def ccp(k, c):    \n",
    "    \"\"\"\n",
    "    coupon collector problem's probabilty;\n",
    "    func2 on http://aquarius10.cse.kyutech.ac.jp/~otabe/shokugan/sg2.html\n",
    "    \"\"\"\n",
    "\n",
    "    \n",
    "    is_returned_single_value = isinstance(k, int)\n",
    "    if is_returned_single_value:\n",
    "        k = np.array([k])\n",
    "\n",
    "    ret = []\n",
    "    for _k in k:\n",
    "        p = np.zeros(c + 1)\n",
    "        p[0] = 1\n",
    "        for j in range(_k):\n",
    "            for i in range(c, -1, -1):\n",
    "                p[i] = p[i] * i / c + p[i - 1] * (c - i + 1.) / c\n",
    "                if i == 0:\n",
    "                    p[0] = 0\n",
    "\n",
    "        ret.append(p[c])\n",
    "\n",
    "    ret = np.array(ret)\n",
    "\n",
    "    if is_returned_single_value:\n",
    "        return ret[0]\n",
    "    else:\n",
    "        return ret\n",
    "\n",
    "\n",
    "def collision(k, c):\n",
    "    \"\"\"\n",
    "    Collision term\n",
    "    \"\"\"\n",
    "    return collision_term(np.array([k]), c)[0]\n",
    "\n",
    "harmonic = lambda c: np.log(c) + 0.577\n",
    "\n",
    "\n",
    "def arora(k, c, loss):\n",
    "    if c > k:\n",
    "        return None\n",
    "    \n",
    "    # since k is the number of negative samples,\n",
    "    # we added the positive classes to k: k+1\n",
    "    v = ccp(k+1, c)\n",
    "    coeff = 1. / ((1. - tau(k, c)) * v)\n",
    "    return coeff * (loss - collision(k, c))\n",
    "\n",
    "def nozawa(k, c, loss):\n",
    "    if c > k:\n",
    "        return None  \n",
    "\n",
    "    # since k is the number of negative samples,\n",
    "    # we added the positive classes to k: k+1\n",
    "    v = ccp(k+1, c)\n",
    "    return (2. * loss - collision(k, c)) / v\n",
    "\n",
    "def ash(k, c, loss):\n",
    "\n",
    "    coeff = 2 * np.maximum(1, 2 * (c - 1) * harmonic(c - 1) / k) / (1 - tau(k, c))\n",
    "    return coeff * (loss - collision(k, c))\n",
    "\n",
    "def ours(k, c, loss):\n",
    "    return loss + ub_intercept(k, c, b=1.0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5fcdbc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "rename = {\"wiki3029\": \"Wiki-3029\", \"cifar100\": \"CIFAR-100\", \"cifar10\": \"CIFAR-10\"}\n",
    "\n",
    "for dataset in set(merged_df.name.values):\n",
    "\n",
    "    idx = merged_df[merged_df[\"name\"] == dataset].groupby([\"seed\", \"num-used-classes\", \"neg-size\"])[\"supervised-val-acc\"].idxmax()\n",
    "    df = merged_df.loc[idx,]\n",
    "    \n",
    "    negs = df[\"neg-size\"].values\n",
    "    seeds = df[\"seed\"].values\n",
    "    Cs = df[\"num-used-classes\"].values\n",
    "    test_contrastive_losses = df[\"contrastive-test-loss\"].values\n",
    "    test_sup_loss = df[\"supervised-test-loss\"].values\n",
    "\n",
    "    data = []\n",
    "    for row in df.iterrows():\n",
    "        row = row[1]\n",
    "        seed = row[\"seed\"]        \n",
    "        K = row[\"neg-size\"]\n",
    "        C = row[\"num-used-classes\"]\n",
    "        supervised_loss = row[\"supervised-test-loss\"]\n",
    "        contrastive_loss = row[\"contrastive-test-loss\"]\n",
    "        new_d_row = [\n",
    "            seed, C, K, supervised_loss,\n",
    "            arora(K, C, contrastive_loss),\n",
    "            nozawa(K, C, contrastive_loss),\n",
    "            ash(K, C, contrastive_loss),\n",
    "            ours(K, C, contrastive_loss)\n",
    "        ]\n",
    "        \n",
    "        data.append(new_d_row)\n",
    "    \n",
    "    df = pd.DataFrame(data, columns=[\"seed\", \"C\", \"K\", \"sup_loss\", \"arora\", \"nozawa\", \"ash\", \"ours\"])\n",
    "    df = df.groupby([\"C\", \"K\"]).mean().reset_index()\n",
    "\n",
    "    for c in set(df[\"C\"].values):\n",
    "        _df = df[df[\"C\"] == c]\n",
    "        marker = \"o\"\n",
    "        plt.figure(figsize=(4, 3.5))\n",
    "        plt.plot(_df[\"ours\"].values, \"-\", color=\"C0\", marker=marker, label=\"Ours\")\n",
    "        plt.plot(_df[\"arora\"].values, \"-\", color=\"C1\", marker=marker, label=\"Arora et al.\")\n",
    "        plt.plot(_df[\"nozawa\"].values, \"-\", color=\"C2\", marker=marker, label=\"Nozawa \\& Sato\")\n",
    "        plt.plot(_df[\"ash\"].values, \"-\", color=\"C3\", marker=marker, label=\"Ash et al.\")\n",
    "        plt.plot(_df[\"sup_loss\"].values, \"x\", color=\"black\", markersize=13, label=\"Supervised loss\")        \n",
    "        plt.yscale(\"log\")\n",
    "        plt.xticks(np.arange(len(_df[\"K\"].values)), _df[\"K\"].values)\n",
    "        plt.legend(fontsize=10)\n",
    "\n",
    "        plt.xlabel(r\"$K$\")\n",
    "\n",
    "        if dataset == \"cifar10\":\n",
    "            plt.savefig(\"../../papers/figures/{}-upper-bound.pdf\".format(dataset))\n",
    "\n",
    "        # show title only for notebook\n",
    "        title = \"{}, C={}\".format(rename[dataset], c)\n",
    "        plt.title(title)\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9b73f51-e54b-4d4a-b85d-28e55e7f5695",
   "metadata": {},
   "source": [
    "# Closer look of `ours` and `supervised loss`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0830958-3f57-49f9-b78e-0f17cf9c10fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "rename = {\"wiki3029\": \"Wiki-3029\", \"cifar100\": \"CIFAR-100\", \"cifar10\": \"CIFAR-10\"}\n",
    "\n",
    "for dataset in set(merged_df.name.values):\n",
    "\n",
    "    idx = merged_df[merged_df[\"name\"] == dataset].groupby([\"seed\", \"num-used-classes\", \"neg-size\"])[\"supervised-val-acc\"].idxmax()\n",
    "    df = merged_df.loc[idx,]\n",
    "    \n",
    "    negs = df[\"neg-size\"].values\n",
    "    seeds = df[\"seed\"].values\n",
    "    Cs = df[\"num-used-classes\"].values\n",
    "    test_contrastive_losses = df[\"contrastive-test-loss\"].values\n",
    "    test_sup_loss = df[\"supervised-test-loss\"].values\n",
    "\n",
    "    data = []\n",
    "    for row in df.iterrows():\n",
    "        row = row[1]\n",
    "        seed = row[\"seed\"]        \n",
    "        K = row[\"neg-size\"]\n",
    "        C = row[\"num-used-classes\"]\n",
    "        supervised_loss = row[\"supervised-test-loss\"]\n",
    "        contrastive_loss = row[\"contrastive-test-loss\"]\n",
    "        new_d_row = [\n",
    "            seed, C, K, supervised_loss,\n",
    "            ours(K, C, contrastive_loss)\n",
    "        ]\n",
    "        \n",
    "        data.append(new_d_row)\n",
    "    \n",
    "    df = pd.DataFrame(data, columns=[\"seed\", \"C\", \"K\", \"sup_loss\", \"ours\"])\n",
    "    mean = df.groupby([\"C\", \"K\"]).mean().reset_index()\n",
    "    std = df.groupby([\"C\", \"K\"]).std().reset_index()    \n",
    "\n",
    "    for c in set(df[\"C\"].values):\n",
    "        _df = mean[mean[\"C\"] == c]\n",
    "        _df_std = std[std[\"C\"] == c]        \n",
    "        marker = \"o\"\n",
    "        plt.figure(figsize=(4, 3.5))\n",
    "        plt.errorbar(np.arange(len(_df[\"ours\"].values)), _df[\"ours\"].values,  yerr=_df_std[\"ours\"].values, color=\"C0\", marker=marker, capsize=5, label=\"Ours\")\n",
    "        plt.errorbar(np.arange(len(_df[\"sup_loss\"].values)), _df[\"sup_loss\"].values, ls=\"\", yerr=_df_std[\"sup_loss\"].values, color=\"black\", marker=\"x\", markersize=8, capsize=5, label=\"Supervised loss\")\n",
    "        plt.xticks(np.arange(len(_df[\"K\"].values)), _df[\"K\"].values)\n",
    "        plt.legend(fontsize=10)\n",
    "        plt.xlabel(r\"$K$\")\n",
    "\n",
    "        if dataset != \"wiki3029\":\n",
    "            plt.savefig(\"../../papers/figures/closer-look-{}-upper-bound.pdf\".format(dataset))\n",
    "\n",
    "        # show title only for notebook\n",
    "        title = \"{}, $C={}$\".format(rename[dataset], c)\n",
    "        plt.title(title)\n",
    "        plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb105d99-5f29-47e3-987f-83963bee3b86",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "4cd7ab41f5fca4b9b44701077e38c5ffd31fe66a6cab21e0214b68d958d0e462"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
