{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cae861f-9053-4cac-8d71-923242d9a137",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import wandb\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cd64f31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['http_proxy'] = \"http://10.0.0.12:8001\" \n",
    "os.environ['https_proxy'] = \"http://10.0.0.12:8001\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31302b0d-8529-411b-99b6-52012e30a4e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "api = wandb.Api()\n",
    "entity = api.default_entity\n",
    "project = \"vtt\"\n",
    "\n",
    "# get runs from the project\n",
    "def filter_runs(filters=None, sort=None):\n",
    "    runs = api.runs(f\"{entity}/{project}\", filters=filters)\n",
    "    runs = [\n",
    "        run\n",
    "        for run in runs\n",
    "        if (\"test/CIDEr\" in run.summary and \"model/_target_\" in run.config)\n",
    "    ]\n",
    "    if sort is not None:\n",
    "        runs = sorted(runs, key=sort)\n",
    "    print(f\"Find {len(runs)} runs in {entity}/{project}\")\n",
    "    return runs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e868149f-f93d-4844-b39c-33902cfd2ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "filters = {\"tags\": {\"$in\": [\"miss\"]}}\n",
    "runs = filter_runs(filters, sort=lambda run: run.summary[\"test/CIDEr\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "149934b9-9476-4e24-a5ab-317f5f696684",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.font_manager\n",
    "print(f\"available fonts: {sorted([f.name for f in matplotlib.font_manager.fontManager.ttflist])}\")\n",
    "\n",
    "plt.style.use('seaborn-muted')\n",
    "\n",
    "plt.rcParams[\"figure.dpi\"] = 300\n",
    "plt.rcParams[\"savefig.dpi\"] = 300\n",
    "plt.rcParams[\"savefig.format\"] = \"pdf\"\n",
    "plt.rcParams[\"savefig.bbox\"] = \"tight\"\n",
    "plt.rcParams[\"savefig.pad_inches\"] = 0.1\n",
    "\n",
    "plt.rcParams['figure.titlesize'] = 18\n",
    "plt.rcParams['axes.titlesize'] = 18\n",
    "plt.rcParams['font.family'] = 'Helvetica'\n",
    "plt.rcParams['font.size'] = 18\n",
    "\n",
    "plt.rcParams[\"lines.linewidth\"] = 2\n",
    "plt.rcParams[\"scatter.marker\"] = \"o\"\n",
    "plt.rcParams['axes.labelsize'] = 16\n",
    "plt.rcParams['axes.labelweight'] = 'bold'\n",
    "plt.rcParams['xtick.labelsize'] = 16\n",
    "plt.rcParams['ytick.labelsize'] = 16\n",
    "plt.rcParams['legend.fontsize'] = 16\n",
    "plt.rcParams['axes.linewidth'] = 2\n",
    "plt.rcParams['axes.titlepad'] = 6\n",
    "\n",
    "plt.rcParams['mathtext.fontset'] = 'dejavuserif'\n",
    "plt.rcParams['mathtext.it'] = 'serif:italic'\n",
    "plt.rcParams['lines.marker'] = \"\"\n",
    "plt.rcParams['legend.frameon'] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "336ee913-1742-4a1f-ba3f-4b6356e17309",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = defaultdict(list)\n",
    "models = []\n",
    "for run in runs[::-1]:\n",
    "    model_name = run.config[\"model/_target_\"].split(\".\")[-1]\n",
    "    if \"model/image_encoder\" in run.config:\n",
    "        image_encoder = run.config[\"model/image_encoder\"]\n",
    "    else:\n",
    "        image_encoder = \"ResNet152\"\n",
    "    image_encoder = image_encoder.replace(\"resnet\", \"ResNet\")\n",
    "    if image_encoder == \"ViT-L/14\":\n",
    "        image_encoder = \"CLIP\"\n",
    "    elif image_encoder == \"inception_v3\":\n",
    "        image_encoder = \"InceptionV3\"\n",
    "    if \"TTNet\" in model_name:\n",
    "        if model_name == \"TTNetDiff\":\n",
    "            if run.config[\"model/mask_ratio\"] > 0:\n",
    "                model_name = \"TTNet\"\n",
    "            else:\n",
    "                model_name = \"TTNet w/o MTM\"\n",
    "        elif model_name == \"TTNetMTM\":\n",
    "            model_name = \"TTNet w/o diff\"\n",
    "        else:\n",
    "            model_name = \"TTNet$_\\\\text{Base}$\"\n",
    "    elif image_encoder == \"CLIP\":\n",
    "        model_name += \"*\"\n",
    "    models.append(model_name)\n",
    "    results[\"Full\"].append(run.summary[\"test/CIDEr\"] * 100)\n",
    "    results[\"Randomly mask one\"].append(run.summary[\"miss_one_test/CIDEr\"] * 100)\n",
    "    results[\"Initial & Final\"].append(run.summary[\"init_fin_only_test/CIDEr\"] * 100)\n",
    "    df = pd.DataFrame(results, index=models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "159cf19c-a6e0-4041-a008-970b3bb52ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c03543-3bfb-4411-b18a-9633e953918b",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _, row in df.iterrows():\n",
    "    plt.plot([0,1,2], row.values, 's', markersize=12, ls='-', linewidth=5, label=row.name)\n",
    "plt.xlabel(\"States\")\n",
    "plt.ylabel(\"CIDEr\")\n",
    "plt.xticks([0, 1, 2], ['full', 'randomly mask one', 'start & end only'])\n",
    "plt.legend()\n",
    "plt.savefig(\"miss.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc4d6af3-24f0-4d75-bf46-fe25dbddae4a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
