{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import wandb\n",
    "import torch\n",
    "\n",
    "from distributions.clf import ClfImgMNIST, ClfImgSVHN, ClfText, ClfImgMMNIST, ClfImgTransMMNIST, ClfCelebAImg, ClfCelebAText\n",
    "from models.clf_model import Model\n",
    "\n",
    "# import dataset\n",
    "sys.path.append('../')\n",
    "from multimodal_datasets.SVHNMNISTDataset import SVHNMNIST\n",
    "from multimodal_datasets.CelebADataset import CelebaDataset\n",
    "from multimodal_datasets.PolyMNIST import PolyMNIST\n",
    "# Dataset root: drop every 'jmvae_journal' substring from the cwd to reach the parent checkout.\n",
    "# NOTE(review): dir_data is not used by any later cell of this notebook.\n",
    "project_root = os.getcwd().replace('jmvae_journal', '')\n",
    "dir_data = os.path.join(project_root, 'multimodal_datasets/data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset whose trained classifiers are exported; the selection cell below\n",
    "# handles \"SVHNMNIST\", \"CelebAText\", \"PolyMNIST\" and \"TranslatedPolyMNIST\".\n",
    "data = 'TranslatedPolyMNIST'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select the classifier architecture and export location for `data`.\n",
    "# NOTE(review): the SVHNMNIST and CelebAText branches read `modality_id`\n",
    "# (SVHNMNIST also `num_features` / `text_dim`), which no cell of this\n",
    "# notebook defines -- set them before selecting those datasets.\n",
    "if data == \"SVHNMNIST\":\n",
    "    # Construct only the classifier for the requested modality.\n",
    "    clf_factories = [ClfImgMNIST, ClfImgSVHN, lambda: ClfText(num_features, text_dim)]\n",
    "    clf = clf_factories[modality_id]()\n",
    "    dir_clf = \"./trained_classifiers/trained_clfs_mst\"\n",
    "    clf_name = \"clf_m\"\n",
    "\n",
    "elif data == \"CelebAText\":\n",
    "    clf_factories = [ClfCelebAImg, ClfCelebAText]\n",
    "    clf = clf_factories[modality_id]()\n",
    "    dir_clf = \"./trained_classifiers/trained_clfs_celeba\"\n",
    "    clf_name = \"clf_m\"\n",
    "\n",
    "elif data == \"PolyMNIST\":\n",
    "    clf = ClfImgMMNIST()\n",
    "    dir_clf = \"./trained_classifiers/trained_clfs_polyMNIST\"\n",
    "    clf_name = \"pretrained_img_to_digit_clf_m\"\n",
    "\n",
    "elif data == \"TranslatedPolyMNIST\":\n",
    "    clf = ClfImgTransMMNIST()\n",
    "    dir_clf = \"./trained_classifiers/trained_clfs_translatedpolyMNIST\"\n",
    "    clf_name = \"pretrained_img_to_digit_clf_m\"\n",
    "\n",
    "else:\n",
    "    # Fail here rather than with a NameError on `clf` in a later cell.\n",
    "    raise ValueError(\n",
    "        \"Unknown data %r; expected SVHNMNIST, CelebAText, PolyMNIST or TranslatedPolyMNIST\" % (data,)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_and_set_clf(id, project=\"masa-su/SMVAEs\"):\n",
    "    \"\"\"Export the classifier weights of one W&B run to a local file.\n",
    "\n",
    "    The run config supplies ``data`` and ``modality_id``; the classifier is\n",
    "    loaded via ``Model(clf).load_models`` and ``model.clf.state_dict()`` is\n",
    "    saved to ``os.path.join(dir_clf, clf_name + str(modality_id))``.\n",
    "\n",
    "    Relies on the globals ``data``, ``clf``, ``dir_clf`` and ``clf_name`` set\n",
    "    by the classifier-selection cell.\n",
    "\n",
    "    Args:\n",
    "        id: W&B run id (name kept for backward compatibility).\n",
    "        project: W&B ``entity/project`` prefix; defaults to the previously\n",
    "            hard-coded ``\"masa-su/SMVAEs\"``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the run's ``data`` differs from the selected ``data``,\n",
    "            because ``clf`` / ``dir_clf`` / ``clf_name`` belong to that dataset.\n",
    "    \"\"\"\n",
    "    path = \"%s/%s\" % (project, id)\n",
    "\n",
    "    # Read the run config before loading weights so a mismatch fails fast.\n",
    "    config = wandb.Api().run(path).config\n",
    "    run_data = config[\"data\"]\n",
    "    modality_id = config[\"modality_id\"]\n",
    "    print(run_data, modality_id)\n",
    "\n",
    "    if run_data != data:\n",
    "        raise ValueError(\n",
    "            \"Run %s was trained on %r but data=%r; set data=%r and re-run \"\n",
    "            \"the classifier-selection cell.\" % (path, run_data, data, run_data)\n",
    "        )\n",
    "\n",
    "    # NOTE(review): every call wraps the same global `clf` object; presumably\n",
    "    # load_models writes into it, so it is saved before the next run loads.\n",
    "    model = Model(clf)\n",
    "    model.load_models(path)\n",
    "\n",
    "    # makedirs (unlike mkdir) also creates ./trained_classifiers if missing.\n",
    "    os.makedirs(dir_clf, exist_ok=True)\n",
    "\n",
    "    save_dir = os.path.join(dir_clf, clf_name + str(modality_id))\n",
    "    print(save_dir)\n",
    "    torch.save(model.clf.state_dict(), save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# W&B run ids (project masa-su/SMVAEs) whose classifier weights are exported.\n",
    "ids = [\n",
    "    \"v6ujtmr2\",\n",
    "    \"2j6czrg1\",\n",
    "    \"2v9zm8kn\",\n",
    "    \"1wshsf3r\",\n",
    "    \"2xhujha7\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TranslatedPolyMNIST 0\n",
      "./trained_classifiers/trained_clfs_translatedpolyMNIST/pretrained_img_to_digit_clf_m0\n",
      "TranslatedPolyMNIST 1\n",
      "./trained_classifiers/trained_clfs_translatedpolyMNIST/pretrained_img_to_digit_clf_m1\n",
      "TranslatedPolyMNIST 2\n",
      "./trained_classifiers/trained_clfs_translatedpolyMNIST/pretrained_img_to_digit_clf_m2\n",
      "TranslatedPolyMNIST 3\n",
      "./trained_classifiers/trained_clfs_translatedpolyMNIST/pretrained_img_to_digit_clf_m3\n",
      "TranslatedPolyMNIST 4\n",
      "./trained_classifiers/trained_clfs_translatedpolyMNIST/pretrained_img_to_digit_clf_m4\n"
     ]
    }
   ],
   "source": [
    "# Export the classifier of every listed run.\n",
    "for run_id in ids:\n",
    "    load_and_set_clf(run_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
