{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-1-8e323d864f07>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[0mget_ipython\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'reset'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'-f'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      2\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0margparse\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 3\u001b[1;33m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      4\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mdatetime\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mjson\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python37\\site-packages\\torch\\__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m    119\u001b[0m         \u001b[0mis_loaded\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    120\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mwith_load_library_flags\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 121\u001b[1;33m             \u001b[0mres\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mkernel32\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mLoadLibraryExW\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdll\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0x00001100\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    122\u001b[0m             \u001b[0mlast_error\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mctypes\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_last_error\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    123\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mres\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mlast_error\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[1;36m126\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "%reset -f\n",
    "import argparse\n",
    "import torch\n",
    "import datetime\n",
    "import json\n",
    "import yaml\n",
    "import os\n",
    "\n",
    "from main_model import CSDI_Physio\n",
    "from dataset_physio import get_dataloader\n",
    "from utils import train, evaluate\n",
    "\n",
    "# Options are declared argparse-style but parsed from an empty list, so the\n",
    "# defaults below are the values the notebook actually runs with.\n",
    "parser = argparse.ArgumentParser(description=\"CSDI\")\n",
    "parser.add_argument(\"--config\", type=str, default=\"base.yaml\")\n",
    "parser.add_argument(\n",
    "    \"--device\",\n",
    "    # Use the first GPU when CUDA is available; otherwise fall back to the CPU\n",
    "    # so the cell also runs on machines without a GPU.\n",
    "    default=\"cuda:0\" if torch.cuda.is_available() else \"cpu\",\n",
    "    help=\"torch device used for the model, e.g. cuda:0 or cpu\",\n",
    ")\n",
    "parser.add_argument(\"--seed\", type=int, default=1)\n",
    "parser.add_argument(\"--testmissingratio\", type=float, default=0.5)\n",
    "parser.add_argument(\n",
    "    \"--nfold\", type=int, default=0, help=\"for 5fold test (valid value:[0-4])\"\n",
    ")\n",
    "parser.add_argument(\"--unconditional\", action=\"store_true\")\n",
    "# An empty --modelfolder means the model is trained below; a non-empty value\n",
    "# (e.g. \"physio\") loads ./save/<modelfolder>/model.pth instead of training.\n",
    "parser.add_argument(\"--modelfolder\", type=str, default=\"\")\n",
    "parser.add_argument(\"--nsample\", type=int, default=20)\n",
    "\n",
    "# Pass an explicit empty argv: inside a notebook sys.argv belongs to the kernel.\n",
    "args = parser.parse_args([])\n",
    "print(args)\n",
    "\n",
    "# Load the YAML experiment configuration selected by --config.\n",
    "path = f\"config/{args.config}\"\n",
    "with open(path, \"r\") as config_file:\n",
    "    config = yaml.safe_load(config_file)\n",
    "\n",
    "# Override two entries of the model section with the notebook-level arguments.\n",
    "config[\"model\"].update(\n",
    "    is_unconditional=args.unconditional,\n",
    "    test_missing_ratio=args.testmissingratio,\n",
    ")\n",
    "\n",
    "# Show the effective configuration for this run.\n",
    "print(json.dumps(config, indent=4))\n",
    "\n",
    "current_time = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "# Results go to one fixed folder. For a separate folder per run, build the name\n",
    "# from args.nfold and current_time instead.\n",
    "foldername = \"./save/physio\"\n",
    "print('model folder:', foldername)\n",
    "os.makedirs(foldername, exist_ok=True)\n",
    "\n",
    "# Save a snapshot of the effective config inside the output folder. The path is\n",
    "# joined with os.path.join because foldername has no trailing slash: plain string\n",
    "# concatenation would create \"./save/physioconfig.json\" next to the folder.\n",
    "with open(os.path.join(foldername, \"config.json\"), \"w\") as f:\n",
    "    json.dump(config, f, indent=4)\n",
    "\n",
    "# Build the three data loaders; every setting is handed to get_dataloader by keyword.\n",
    "loaders = get_dataloader(\n",
    "    batch_size=config[\"train\"][\"batch_size\"],\n",
    "    missing_ratio=config[\"model\"][\"test_missing_ratio\"],\n",
    "    seed=args.seed,\n",
    "    nfold=args.nfold,\n",
    ")\n",
    "train_loader, valid_loader, test_loader = loaders\n",
    "\n",
    "# Build the model and move it to the requested device.\n",
    "model = CSDI_Physio(config, args.device).to(args.device)\n",
    "\n",
    "if args.modelfolder == \"\":\n",
    "    # No checkpoint folder given: train the model.\n",
    "    train(\n",
    "        model,\n",
    "        config[\"train\"],\n",
    "        train_loader,\n",
    "        valid_loader=valid_loader,\n",
    "        foldername=foldername,\n",
    "        test_loader=test_loader,\n",
    "        nsample=args.nsample,\n",
    "    )\n",
    "else:\n",
    "    # Load saved weights instead of training. map_location remaps tensors that were\n",
    "    # saved on another device (e.g. a CUDA checkpoint opened on a CPU-only machine)\n",
    "    # onto args.device; without it torch.load tries to restore the original device.\n",
    "    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.\n",
    "    checkpoint_path = \"./save/\" + args.modelfolder + \"/model.pth\"\n",
    "    model.load_state_dict(torch.load(checkpoint_path, map_location=args.device))\n",
    "\n",
    "# Final evaluation on the test loader, reusing the nsample setting and the\n",
    "# output folder defined earlier in this cell.\n",
    "evaluate(\n",
    "    model, config[\"train\"], test_loader,\n",
    "    nsample=args.nsample, scaler=1, foldername=foldername,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
