{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reset -f\n",
    "import argparse\n",
    "import torch\n",
    "import datetime\n",
    "import json\n",
    "import yaml\n",
    "import os\n",
    "\n",
    "from dataset_pm25 import get_dataloader\n",
    "from main_model import CSDI_PM25\n",
    "from utils import train, evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Namespace(config='base.yaml', device='cuda:0', modelfolder='', targetstrategy='mix', validationindex=0, nsample=30, unconditional=False)\n",
      "{\n",
      "    \"train\": {\n",
      "        \"epochs\": 200,\n",
      "        \"batch_size\": 16,\n",
      "        \"lr\": 0.001\n",
      "    },\n",
      "    \"diffusion\": {\n",
      "        \"layers\": 4,\n",
      "        \"channels\": 64,\n",
      "        \"nheads\": 8,\n",
      "        \"diffusion_embedding_dim\": 128,\n",
      "        \"beta_start\": 0.0001,\n",
      "        \"beta_end\": 0.5,\n",
      "        \"num_steps\": 10,\n",
      "        \"schedule\": \"quad\"\n",
      "    },\n",
      "    \"model\": {\n",
      "        \"is_unconditional\": false,\n",
      "        \"timeemb\": 128,\n",
      "        \"featureemb\": 16,\n",
      "        \"target_strategy\": \"mix\"\n",
      "    }\n",
      "}\n",
      "model folder: ./save/pm25_constrained/\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 303/303 [00:23<00:00, 13.14it/s, avg_epoch_loss=0.305, epoch=0]\n",
      "  0%|          | 0/6 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "AttributeError",
     "evalue": "'CSDI_PM25' object has no attribute 'impute'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_766/3688184364.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     58\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodelfolder\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;31m#\"pm25_rectified\": #\"pm25\":##\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m     train(\n\u001b[0m\u001b[1;32m     60\u001b[0m         \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     61\u001b[0m         \u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"train\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/utils.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, config, train_loader, valid_loader, valid_epoch_interval, foldername, test_loader, nsample, scaler, mean_scaler)\u001b[0m\n\u001b[1;32m     90\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mepoch_no\u001b[0m \u001b[0;34m%\u001b[0m\u001b[0;36m20\u001b[0m \u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     91\u001b[0m             \u001b[0;31m#pass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m             \u001b[0mevaluate_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnsample\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscaler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmean_scaler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfoldername\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfoldername\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     93\u001b[0m             \u001b[0;31m#evaluate_(model,config,test_loader, nsample, scaler=1, foldername=foldername)  #physio\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     94\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/utils.py\u001b[0m in \u001b[0;36mevaluate_\u001b[0;34m(model, config, test_loader, nsample, scaler, mean_scaler, foldername)\u001b[0m\n\u001b[1;32m    235\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmininterval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5.0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaxinterval\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m50.0\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mit\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    236\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mbatch_no\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_batch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 237\u001b[0;31m                     \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnsample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    238\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    239\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/main_model_con.py\u001b[0m in \u001b[0;36mevaluate\u001b[0;34m(self, batch, n_samples)\u001b[0m\n\u001b[1;32m    509\u001b[0m             \u001b[0mside_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_side_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobserved_tp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcond_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    510\u001b[0m             \u001b[0;31m#samples = self.impute_rk45(observed_data, cond_mask, side_info, n_samples)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 511\u001b[0;31m             \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobserved_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcond_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mside_info\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_samples\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    512\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    513\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcut_length\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m  \u001b[0;31m# to avoid double evaluation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1612\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1613\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1614\u001b[0;31m         raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m   1615\u001b[0m             type(self).__name__, name))\n\u001b[1;32m   1616\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'CSDI_PM25' object has no attribute 'impute'"
     ]
    }
   ],
   "source": [
    "parser = argparse.ArgumentParser(description=\"CSDI\")\n",
    "parser.add_argument(\"--config\", type=str, default=\"base.yaml\")\n",
    "parser.add_argument('--device', default='cuda:0', help='Device for Attack')\n",
    "#parser.add_argument(\"--modelfolder\", type=str, default=\"pm25\")\n",
    "parser.add_argument(\"--modelfolder\", type=str, default=\"\")\n",
    "parser.add_argument(\n",
    "    \"--targetstrategy\", type=str, default=\"mix\", choices=[\"mix\", \"random\", \"historical\"]\n",
    ")\n",
    "parser.add_argument(\n",
    "    \"--validationindex\", type=int, default=0, help=\"index of month used for validation (value:[0-7])\"\n",
    ")\n",
    "parser.add_argument(\"--nsample\", type=int, default=30)# csdi: default=100)\n",
    "parser.add_argument(\"--unconditional\", action=\"store_true\")\n",
    "\n",
    "args = parser.parse_args([])\n",
    "print(args)\n",
    "\n",
    "# Read the YAML experiment configuration selected by --config from config/.\n",
    "path = f\"config/{args.config}\"\n",
    "with open(path, \"r\") as config_file:\n",
    "    config = yaml.safe_load(config_file)\n",
    "\n",
    "# The parsed options take precedence over the values stored in the YAML file.\n",
    "config[\"model\"].update(\n",
    "    {\n",
    "        \"is_unconditional\": args.unconditional,\n",
    "        \"target_strategy\": args.targetstrategy,\n",
    "    }\n",
    ")\n",
    "\n",
    "# Echo the effective configuration.\n",
    "print(json.dumps(config, indent=4))\n",
    "\n",
    "# Timestamp of this run; only needed when a per-run output folder is wanted, e.g.\n",
    "#   \"./save/pm25_validationindex\" + str(args.validationindex) + \"_\" + current_time + \"/\"\n",
    "current_time = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "\n",
    "# Output folder of this run. The trailing slash matters: file names are\n",
    "# appended to this string directly.\n",
    "foldername = \"./save/pm25/\"\n",
    "print(f\"model folder: {foldername}\")\n",
    "\n",
    "# Create the folder if needed and store the effective configuration in it.\n",
    "os.makedirs(foldername, exist_ok=True)\n",
    "config_path = foldername + \"config.json\"\n",
    "with open(config_path, \"w\") as f:\n",
    "    json.dump(config, f, indent=4)\n",
    "\n",
    "# Build the train/validation/test loaders together with the two scalers that\n",
    "# get_dataloader returns alongside them.\n",
    "batch_size = config[\"train\"][\"batch_size\"]\n",
    "loaders_and_scalers = get_dataloader(\n",
    "    batch_size, device=args.device, validindex=args.validationindex\n",
    ")\n",
    "train_loader, valid_loader, test_loader, scaler, mean_scaler = loaders_and_scalers\n",
    "\n",
    "# Instantiate the model and move it to the selected device.\n",
    "model = CSDI_PM25(config, args.device).to(args.device)\n",
    "\n",
    "# Train from scratch when no pretrained folder is given; otherwise restore the\n",
    "# weights stored in ./save/<modelfolder>/model.pth.\n",
    "if args.modelfolder == \"\":\n",
    "    train(\n",
    "        model,\n",
    "        config[\"train\"],\n",
    "        train_loader,\n",
    "        valid_loader=valid_loader,\n",
    "        foldername=foldername,\n",
    "        test_loader=test_loader,\n",
    "        nsample=args.nsample,\n",
    "        scaler=scaler,\n",
    "        mean_scaler=mean_scaler,\n",
    "    )\n",
    "else:\n",
    "    checkpoint_path = os.path.join(\"./save\", args.modelfolder, \"model.pth\")\n",
    "    # map_location remaps tensors saved on another device (e.g. a GPU checkpoint\n",
    "    # opened on a CPU-only machine) onto args.device.\n",
    "    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.\n",
    "    state_dict = torch.load(checkpoint_path, map_location=args.device)\n",
    "    model.load_state_dict(state_dict)\n",
    "\n",
    "# Run the evaluation on the test loader with the same sampling and scaling\n",
    "# settings that are given to train().\n",
    "eval_options = dict(\n",
    "    nsample=args.nsample,\n",
    "    scaler=scaler,\n",
    "    mean_scaler=mean_scaler,\n",
    "    foldername=foldername,\n",
    ")\n",
    "evaluate(model, config[\"train\"], test_loader, **eval_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
