{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PolyMNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "d_size_m1 = 28**2\n",
    "d_size_m2 = 3*32**2\n",
    "d_size_m3 = 71"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3.9183673469387754, 1.0, 43.267605633802816)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d_size_m2/d_size_m1, 1.0, d_size_m2/d_size_m3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.6/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from torchvision.utils import make_grid\n",
    "import numpy as np\n",
    "import torch\n",
    "import itertools\n",
    "import wandb\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from load_models import load_model, imshow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_metrics(name):\n",
    "    path = \"masa-su/SMVAEs/%s\" % name\n",
    "    model, train_loader, test_loader = load_model(path, \"cuda:4\", {\"eval_fid\": True, 'plot_image': False, 's_dim': False, 'gamma': 1.0})\n",
    "    \n",
    "    test_dict = model.test(0, test_loader)\n",
    "    return test_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_logs(summary_dict, modality_id=[0,1,2,3,4]):\n",
    "\n",
    "    ac_list = []\n",
    "    lt_list = []\n",
    "    fid_list = []\n",
    "\n",
    "    for j in range(1, len(modality_id)+1):\n",
    "        _ac_list = []\n",
    "        _lt_list = []\n",
    "        _fid_list = []\n",
    "        for conb in itertools.combinations(modality_id, j):\n",
    "            generation_id = list(set(modality_id)-set(conb))\n",
    "\n",
    "            var = [\"x\"+str(k) for k in conb]\n",
    "            name = \"\".join(var)\n",
    "            for i in generation_id:\n",
    "                log_name = \"ac_%d_%s\"% (i, name)\n",
    "                _ac_list.append(summary_dict[log_name])\n",
    "                log_name = \"fid_%d_%s\"% (i, name)\n",
    "                _fid_list.append(summary_dict[log_name])                \n",
    "            log_name = \"lt_%s\"% (name)            \n",
    "            _lt_list.append(summary_dict[log_name])\n",
    "        \n",
    "        if j != len(modality_id):\n",
    "            ac_list.append(np.mean(_ac_list))\n",
    "            fid_list.append(np.mean(_fid_list))            \n",
    "        lt_list.append(np.mean(_lt_list))\n",
    "    \n",
    "    return ac_list, lt_list, fid_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {\"CRMVAE\": {\"id\": [\"2404kui8\"]},\n",
    "          \"MoPoE\": {\"id\": [\"iaayegj5\", \"qflmgx9x\"]},\n",
    "          \"MVAE\": {\"id\": [\"108zk1tx\"]},\n",
    "          \"MMVAE\": {\"id\": [\"1sxksd9s\"]},\n",
    "          \"MMJSD\": {\"id\": [\"1ocz4yid\"]},\n",
    "          \"MVTCAE\": {\"id\": [\"2mh398lf\"]},\n",
    "          \"MMJSD-PoE\":{\"id\": [\"39d8p6wl\"]}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2404kui8\n",
      "/workspace/share/multimodal_datasets/data/MMNIST\n",
      "/workspace/share/multimodal_datasets/data/MMNIST\n",
      "{'lr': 0.001, 'gpu': 1, 'beta': 1, 'data': 'PolyMNIST', 'seed': 1, 'gamma': 1.0, 'model': 'SSMVAE_P', 's_dim': False, 'z_dim': 512, 'epochs': 200, 'kl_set': 0, 'eval_fid': True, 'batch_size': 256, 'beta_smvae': 1, 'clf_epochs': 0, 'fix_weight': 0, 'forward_kl': 1, 'plot_image': False, 'test_epoch': 10, 'modality_id': [0, 1, 2, 3, 4], 'save_models': 1, 'eval_reconst': 0, 'kl_annealing': 0, 'num_sampling': 1, 'use_schedule': 0, 'weight_decay': 0, 'use_batch_gon': 0, 'fix_elbo_smvae': 0, 'optimizer_name': 'adam', 'rec_weight_all': [1.0, 1.0, 1.0, 1.0, 1.0], 'reconst_unimodal': 1, 'kl_annealing_start': -1, 'optimizer_params': {'lr': 0.001, 'weight_decay': 0}}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/39 [00:00<?, ?it/s]"
     ]
    }
   ],
   "source": [
    "for k, v in models.items():\n",
    "    for _id in v[\"id\"]:\n",
    "        if not (_id in v):\n",
    "            print(_id)\n",
    "            try:\n",
    "                dicts = get_metrics(_id)\n",
    "                models[k][_id] = dicts\n",
    "            except:\n",
    "                print('failed', _id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 3, figsize=(18, 4))\n",
    "sns.set(style=\"darkgrid\")\n",
    "\n",
    "ac_idx = fid_idx = [1,2,3,4]\n",
    "lt_idx = [1,2,3,4,5]\n",
    "markers = [\".\", \"o\", \"v\", \"^\", \"<\", \">\", \"1\", \"2\", \"3\"]\n",
    "\n",
    "for i, (k, v) in enumerate(models.items()):\n",
    "    ac_all = []\n",
    "    lt_all = []\n",
    "    fid_all = []\n",
    "    ids = list(v.keys())\n",
    "    ids.remove('id')\n",
    "    if len(ids) != 0:\n",
    "        for _id in ids:\n",
    "            ac, lt, fid = get_logs(v[_id])\n",
    "            ac_all.append(ac)\n",
    "            lt_all.append(lt)\n",
    "            fid_all.append(fid)\n",
    "\n",
    "        ac_all = np.array(ac_all)\n",
    "        ac_mean = ac_all.mean(axis=0)\n",
    "        ac_std = ac_all.std(axis=0)\n",
    "\n",
    "        lt_all = np.array(lt_all)\n",
    "        lt_mean = lt_all.mean(axis=0)\n",
    "        lt_std = lt_all.std(axis=0)\n",
    "\n",
    "        fid_all = np.array(fid_all)\n",
    "        fid_mean = fid_all.mean(axis=0)\n",
    "        fid_std = fid_all.std(axis=0)    \n",
    "\n",
    "        ax[0].fill_between(lt_idx, lt_mean + lt_std, lt_mean - lt_std, alpha=0.2)\n",
    "        ax[0].plot(lt_idx, lt_mean, label=k, marker=markers[i], markeredgewidth=0)\n",
    "\n",
    "        ax[1].fill_between(ac_idx, ac_mean + ac_std, ac_mean - ac_std, alpha=0.2)    \n",
    "        ax[1].plot(ac_idx, ac_mean, label=k, marker=markers[i], markeredgewidth=0)\n",
    "\n",
    "        ax[2].fill_between(fid_idx, fid_mean + fid_std, fid_mean - fid_std, alpha=0.2)    \n",
    "        ax[2].plot(fid_idx, fid_mean, label=k, marker=markers[i], markeredgewidth=0)    \n",
    "    \n",
    "ax[2].legend(bbox_to_anchor=(1.05, 1), prop={'size': 8})\n",
    "#ax[1].legend(prop={'size': 8})\n",
    "ax[0].set_xticks(lt_idx)\n",
    "ax[1].set_xticks(ac_idx)\n",
    "ax[2].set_xticks(fid_idx)\n",
    "#ax[0].set_ylim(0.74, 1.005)\n",
    "#ax[1].set_ylim(0.2, 1.0)\n",
    "ax[1].set_ylabel(\"Cross-coherence\")\n",
    "ax[0].set_ylabel(\"Accuracy of latent representation\")\n",
    "ax[0].set_ylabel(\"FID score\")\n",
    "ax[0].set_xlabel(\"Number of input modalities\")\n",
    "ax[1].set_xlabel(\"Number of input modalities\")\n",
    "ax[2].set_xlabel(\"Number of input modalities\")\n",
    "\n",
    "#fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
