{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import argparse\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from ncm_mnist.ModularUtils.Experiment_Class import Experiment\n",
    "from ncm_mnist.napkin_graph import set_napkin\n",
    "from ncm_mnist.ModularUtils.ControllerModel import get_generators\n",
    "\n",
    "# Oct 21, 19:00 to Oct 23,  21:00 =50 hours\n",
    "\n",
    "# Fall back to the CPU when no CUDA device is available.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Fix the RNG seeds so the torch.randn noise fed to the generators below\n",
    "# is repeatable across Restart & Run All.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n"
   ]
  },
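  {
   "cell_type": "markdown",
   "source": [
    "# True Y ~ P(Y)\n",
    "\n",
    "Real `Y` images from the training pickle, keeping only samples whose `W2a` value is 3 or 5; they are written to `FID_scores/trueP_Y`."
   ],
   "metadata": {
    "collapsed": false
   }
  },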
  {
   "cell_type": "code",
   "execution_count": 27,
   "outputs": [],
   "source": [
    "import os\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "file = '../napkin_mnist/base_data/napkin_mnist_train.pkl'\n",
    "with open(file, 'rb') as f:\n",
    "    real_data = pickle.load(f)  # local, trusted file: never unpickle untrusted data\n",
    "\n",
    "# Every real Y whose W2a value is 3 or 5 is saved here\n",
    "# (presumably the real-image reference set for the FID comparison).\n",
    "img_save_path = \"/root/PycharmProjects/IDGEN/napkin_mnist/FID_scores/trueP_Y\"\n",
    "os.makedirs(img_save_path, exist_ok=True)\n",
    "\n",
    "# Y3 / Y5 keep the selected images per W2a value so later cells can inspect them.\n",
    "Y3 = []\n",
    "Y5 = []\n",
    "idx = 0\n",
    "for sample_idx, dig in enumerate(real_data['W2a']):\n",
    "    if dig == 3 or dig == 5:\n",
    "        y_img = real_data['Y'][sample_idx]\n",
    "        if dig == 3:\n",
    "            Y3.append(y_img)\n",
    "        else:\n",
    "            Y5.append(y_img)\n",
    "        save_image(y_img, f'{img_save_path}/img{idx}.png')\n",
    "        idx += 1\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([3, 32, 32])"
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fail with a clear message instead of an IndexError if Y3 was never filled.\n",
    "assert len(Y3) > 0, 'Y3 is empty: populate it with the W2a == 3 Y images before inspecting one'\n",
    "Y3[0].shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Baseline NCM"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# Load the pickled baseline samples (do_X.pkl); doX['X'] is used further down\n",
    "# as the conditioning X fed to the NCM's Y generator.\n",
    "file = '/root/PycharmProjects/IDGEN/napkin_mnist/baseline_samples/do_X.pkl'\n",
    "with open(file, mode='rb') as f:\n",
    "    doX = pickle.load(f)  # trusted local file"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/PycharmProjects/IDGEN/ncm_mnist/ModularUtils/ControllerConstants.py:15: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n",
      "  torch.nn.init.xavier_uniform(m.weight)\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the trained NCM experiment object (new_experiment=False presumably\n",
    "# reuses the saved experiment rather than starting a fresh one - TODO confirm).\n",
    "exp_name = 'ncmMNIST'\n",
    "\n",
    "Exp = Experiment(\n",
    "    set_napkin,\n",
    "    exp_name=exp_name,\n",
    "    Temperature=1,\n",
    "    temp_min=0.1,\n",
    "    learning_rate=5 * 1e-4,\n",
    "    batch_size=200,\n",
    "    IMAGE_NOISE_DIM=3,\n",
    "    CONF_NOISE_DIM=3,\n",
    "    ENCODED_DIM=10,\n",
    "    Data_intervs=[{}],\n",
    "    num_epochs=301,\n",
    "    new_experiment=False,\n",
    ")\n",
    "\n",
    "# Passed to get_generators when the generator checkpoints are reloaded below.\n",
    "Exp.load_which_models = ['rW1', 'rX', 'rY']\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([3000, 3, 32, 32])"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Size of the X batch used for generation (4-D; presumably N, C, H, W).\n",
    "doX[\"X\"].shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import os\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "# Epochs=[0, 50,100, 150,200,250,295]\n",
    "Epochs = [0]\n",
    "# True: only display the first generated Y of the first epoch (quick visual check).\n",
    "# False: save every generated Y for every listed epoch under FID_scores/ncm<epoch>/.\n",
    "PREVIEW_ONLY = True\n",
    "\n",
    "for epoch in Epochs:\n",
    "    print('Epoch:', epoch)\n",
    "\n",
    "    gfile = f'/root/PycharmProjects/IDGEN/ncm_mnist/SAVED_EXPERIMENTS/ncmFixedTemp/gen_checkpoints/epoch{epoch:003}.pth'\n",
    "    # Map onto the detected device instead of hard-coding \"cuda\", so CPU-only hosts can load it.\n",
    "    checkpoint = torch.load(gfile, map_location=device)\n",
    "    label_generators, _ = get_generators(Exp, Exp.load_which_models)\n",
    "    for label in label_generators:\n",
    "        label_generators[label].load_state_dict(checkpoint[label + \"state_dict\"])\n",
    "\n",
    "    img_save_path = f\"/root/PycharmProjects/IDGEN/napkin_mnist/FID_scores/ncm{epoch:003}\"\n",
    "    os.makedirs(img_save_path, exist_ok=True)\n",
    "\n",
    "    with torch.no_grad():  # sampling only; no autograd graph needed\n",
    "        for img_idx, img in tqdm(enumerate(doX['X']), total=len(doX['X'])):\n",
    "            # Keep X and the noise on the same device so torch.cat cannot hit a device mismatch.\n",
    "            curX = img.unsqueeze(0).to(Exp.DEVICE)\n",
    "            image_noise = torch.randn(curX.shape[0], Exp.IMAGE_NOISE_DIM + Exp.CONF_NOISE_DIM, Exp.IMAGE_SIZE, Exp.IMAGE_SIZE).to(Exp.DEVICE)\n",
    "            gen_input = torch.cat([image_noise, curX], dim=1)\n",
    "            genY = label_generators['Y'](gen_input)\n",
    "\n",
    "            if PREVIEW_ONLY:\n",
    "                # CHW -> HWC for imshow; a bare np.transpose would also swap H and W.\n",
    "                plt.imshow(np.transpose(genY[0].cpu().numpy(), (1, 2, 0)))\n",
    "                break\n",
    "            save_image(genY.squeeze(), f'{img_save_path}/img{img_idx}.png')\n",
    "\n",
    "    if PREVIEW_ONLY:\n",
    "        break\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import imageio.v3 as iio\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Spot-check the first two generated Y images saved under FID_scores/ncm<epoch>/.\n",
    "epoch = 295\n",
    "img_dir = f\"/root/PycharmProjects/IDGEN/napkin_mnist/FID_scores/ncm{epoch:003}\"\n",
    "img1 = plt.imread(f\"{img_dir}/img0.png\")\n",
    "img2 = plt.imread(f\"{img_dir}/img1.png\")\n",
    "\n",
    "fig, axes = plt.subplots(1, 2)\n",
    "for ax, im, name in zip(axes, (img1, img2), ('img0', 'img1')):\n",
    "    ax.imshow(im)\n",
    "    ax.set_title(f'NCM epoch {epoch}: {name}')\n",
    "    ax.axis('off')\n",
    "plt.show()\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# IDGEN: generate samples from P(Y|do(X))\n",
    "\n",
    "Run the following command:\n",
    "\n",
    "```bash\n",
    "python3 gen_final_data.py --pkl_loc=baseline_samples/do_X.pkl --diffuser_loc=final_model_NODROP/wEpoch300/ckpt_300_checkpoint.pt --n_samples=20 --batch_size=200 --device=0 --save_dir=FID_scores/IDGEN250_pkl\n",
    "```\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gettnig images for epoch 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/min/a/user/.local/lib/python3.10/site-packages/pydot.py:17: UserWarning: `pydot` could not import `dot_parser`, so `pydot` will be unable to parse DOT files. The error was:  No module named 'pyparsing'\n",
      "  warnings.warn(\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\u001B[A\n",
      "  5%|▌         | 1/20 [00:07<02:17,  7.22s/it]\u001B[A\n",
      " 10%|█         | 2/20 [00:14<02:08,  7.15s/it]\u001B[A\n",
      " 15%|█▌        | 3/20 [00:21<02:01,  7.12s/it]\u001B[A\n",
      " 20%|██        | 4/20 [00:28<01:53,  7.11s/it]\u001B[A\n",
      " 25%|██▌       | 5/20 [00:35<01:46,  7.11s/it]\u001B[A\n",
      " 30%|███       | 6/20 [00:42<01:39,  7.10s/it]\u001B[A\n",
      " 35%|███▌      | 7/20 [00:49<01:32,  7.10s/it]\u001B[A\n",
      " 40%|████      | 8/20 [00:56<01:25,  7.10s/it]\u001B[A\n",
      " 45%|████▌     | 9/20 [01:04<01:18,  7.10s/it]\u001B[A\n",
      " 50%|█████     | 10/20 [01:11<01:11,  7.10s/it]\u001B[A\n",
      " 55%|█████▌    | 11/20 [01:18<01:03,  7.10s/it]\u001B[A\n",
      " 60%|██████    | 12/20 [01:25<00:56,  7.10s/it]\u001B[A\n",
      " 65%|██████▌   | 13/20 [01:32<00:49,  7.10s/it]\u001B[A\n",
      " 70%|███████   | 14/20 [01:39<00:42,  7.10s/it]\u001B[A\n",
      " 75%|███████▌  | 15/20 [01:46<00:35,  7.10s/it]\u001B[A\n",
      " 80%|████████  | 16/20 [01:53<00:28,  7.10s/it]\u001B[A\n",
      " 85%|████████▌ | 17/20 [02:00<00:21,  7.10s/it]\u001B[A\n",
      " 90%|█████████ | 18/20 [02:07<00:14,  7.10s/it]\u001B[A\n",
      " 95%|█████████▌| 19/20 [02:15<00:07,  7.10s/it]\u001B[A\n",
      "100%|██████████| 20/20 [02:22<00:00,  7.11s/it]\u001B[A\n",
      "200it [02:22,  1.40it/s]              \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Count--> 200\n",
      "gettnig images for epoch 50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/min/a/user/.local/lib/python3.10/site-packages/pydot.py:17: UserWarning: `pydot` could not import `dot_parser`, so `pydot` will be unable to parse DOT files. The error was:  No module named 'pyparsing'\n",
      "  warnings.warn(\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\u001B[A\n",
      "  5%|▌         | 1/20 [00:07<02:17,  7.21s/it]\u001B[A\n",
      " 10%|█         | 2/20 [00:14<02:08,  7.16s/it]\u001B[A\n",
      " 15%|█▌        | 3/20 [00:21<02:01,  7.16s/it]\u001B[A\n",
      " 20%|██        | 4/20 [00:28<01:54,  7.15s/it]\u001B[A\n",
      " 25%|██▌       | 5/20 [00:35<01:47,  7.15s/it]\u001B[A\n",
      " 30%|███       | 6/20 [00:42<01:40,  7.15s/it]\u001B[A\n",
      " 35%|███▌      | 7/20 [00:50<01:32,  7.15s/it]\u001B[A\n",
      " 40%|████      | 8/20 [00:57<01:25,  7.15s/it]\u001B[A\n",
      " 45%|████▌     | 9/20 [01:04<01:18,  7.15s/it]\u001B[A\n",
      " 50%|█████     | 10/20 [01:11<01:11,  7.15s/it]\u001B[A\n",
      " 55%|█████▌    | 11/20 [01:18<01:04,  7.15s/it]\u001B[A\n",
      " 60%|██████    | 12/20 [01:25<00:57,  7.15s/it]\u001B[A\n",
      " 65%|██████▌   | 13/20 [01:33<00:50,  7.16s/it]\u001B[A\n",
      " 70%|███████   | 14/20 [01:40<00:42,  7.16s/it]\u001B[A\n",
      " 75%|███████▌  | 15/20 [01:47<00:35,  7.16s/it]\u001B[A\n",
      " 80%|████████  | 16/20 [01:54<00:28,  7.15s/it]\u001B[A\n",
      " 85%|████████▌ | 17/20 [02:01<00:21,  7.15s/it]\u001B[A\n",
      " 90%|█████████ | 18/20 [02:08<00:14,  7.15s/it]\u001B[A\n",
      " 95%|█████████▌| 19/20 [02:15<00:07,  7.15s/it]\u001B[A\n",
      "100%|██████████| 20/20 [02:23<00:00,  7.15s/it]\u001B[A\n",
      "200it [02:23,  1.39it/s]              \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Count--> 200\n",
      "gettnig images for epoch 100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/min/a/user/.local/lib/python3.10/site-packages/pydot.py:17: UserWarning: `pydot` could not import `dot_parser`, so `pydot` will be unable to parse DOT files. The error was:  No module named 'pyparsing'\n",
      "  warnings.warn(\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\u001B[A\n",
      "  5%|▌         | 1/20 [00:07<02:17,  7.22s/it]\u001B[A\n",
      " 10%|█         | 2/20 [00:14<02:08,  7.15s/it]\u001B[A\n",
      " 15%|█▌        | 3/20 [00:21<02:01,  7.13s/it]\u001B[A\n",
      " 20%|██        | 4/20 [00:28<01:53,  7.12s/it]\u001B[A\n",
      " 25%|██▌       | 5/20 [00:35<01:46,  7.11s/it]\u001B[A\n",
      " 30%|███       | 6/20 [00:42<01:39,  7.11s/it]\u001B[A\n",
      " 35%|███▌      | 7/20 [00:49<01:32,  7.11s/it]\u001B[A\n",
      " 40%|████      | 8/20 [00:56<01:25,  7.11s/it]\u001B[A\n",
      " 45%|████▌     | 9/20 [01:04<01:18,  7.10s/it]\u001B[A\n",
      " 50%|█████     | 10/20 [01:11<01:11,  7.10s/it]\u001B[A\n",
      " 55%|█████▌    | 11/20 [01:18<01:03,  7.10s/it]\u001B[A\n",
      " 60%|██████    | 12/20 [01:25<00:56,  7.10s/it]\u001B[A\n",
      " 65%|██████▌   | 13/20 [01:32<00:49,  7.10s/it]\u001B[A\n",
      " 70%|███████   | 14/20 [01:39<00:42,  7.10s/it]\u001B[A\n",
      " 75%|███████▌  | 15/20 [01:46<00:35,  7.10s/it]\u001B[A\n",
      " 80%|████████  | 16/20 [01:53<00:28,  7.10s/it]\u001B[A\n",
      " 85%|████████▌ | 17/20 [02:00<00:21,  7.10s/it]\u001B[A\n",
      " 90%|█████████ | 18/20 [02:07<00:14,  7.10s/it]\u001B[A\n",
      " 95%|█████████▌| 19/20 [02:15<00:07,  7.10s/it]\u001B[A\n",
      "100%|██████████| 20/20 [02:22<00:00,  7.11s/it]\u001B[A\n",
      "200it [02:22,  1.40it/s]              \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Count--> 200\n",
      "gettnig images for epoch 150\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/min/a/user/.local/lib/python3.10/site-packages/pydot.py:17: UserWarning: `pydot` could not import `dot_parser`, so `pydot` will be unable to parse DOT files. The error was:  No module named 'pyparsing'\n",
      "  warnings.warn(\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\u001B[A\n",
      "  5%|▌         | 1/20 [00:07<02:17,  7.22s/it]\u001B[A\n",
      " 10%|█         | 2/20 [00:14<02:08,  7.15s/it]\u001B[A\n",
      " 15%|█▌        | 3/20 [00:21<02:01,  7.13s/it]\u001B[A\n",
      " 20%|██        | 4/20 [00:28<01:53,  7.12s/it]\u001B[A\n",
      " 25%|██▌       | 5/20 [00:35<01:46,  7.11s/it]\u001B[A\n",
      " 30%|███       | 6/20 [00:42<01:39,  7.11s/it]\u001B[A\n",
      " 35%|███▌      | 7/20 [00:49<01:32,  7.11s/it]\u001B[A\n",
      " 40%|████      | 8/20 [00:56<01:25,  7.11s/it]\u001B[A\n",
      " 45%|████▌     | 9/20 [01:04<01:18,  7.10s/it]\u001B[A\n",
      " 50%|█████     | 10/20 [01:11<01:11,  7.10s/it]\u001B[A\n",
      " 55%|█████▌    | 11/20 [01:18<01:03,  7.10s/it]\u001B[A\n",
      " 60%|██████    | 12/20 [01:25<00:56,  7.10s/it]\u001B[A\n",
      " 65%|██████▌   | 13/20 [01:32<00:49,  7.10s/it]\u001B[A\n",
      " 70%|███████   | 14/20 [01:39<00:42,  7.10s/it]\u001B[A\n",
      " 75%|███████▌  | 15/20 [01:46<00:35,  7.10s/it]\u001B[A\n",
      " 80%|████████  | 16/20 [01:53<00:28,  7.10s/it]\u001B[A\n",
      " 85%|████████▌ | 17/20 [02:00<00:21,  7.10s/it]\u001B[A\n",
      " 90%|█████████ | 18/20 [02:07<00:14,  7.10s/it]\u001B[A\n",
      " 95%|█████████▌| 19/20 [02:15<00:07,  7.10s/it]\u001B[A\n",
      "100%|██████████| 20/20 [02:22<00:00,  7.11s/it]\u001B[A\n",
      "200it [02:22,  1.40it/s]              \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Count--> 200\n",
      "gettnig images for epoch 200\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/min/a/user/.local/lib/python3.10/site-packages/pydot.py:17: UserWarning: `pydot` could not import `dot_parser`, so `pydot` will be unable to parse DOT files. The error was:  No module named 'pyparsing'\n",
      "  warnings.warn(\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\n",
      "  0%|          | 0/20 [00:00<?, ?it/s]\u001B[A\n",
      "  5%|▌         | 1/20 [00:07<02:17,  7.22s/it]\u001B[A\n",
      " 10%|█         | 2/20 [00:14<02:08,  7.16s/it]\u001B[A\n",
      " 15%|█▌        | 3/20 [00:21<02:01,  7.16s/it]\u001B[A\n",
      " 20%|██        | 4/20 [00:28<01:54,  7.16s/it]\u001B[A\n",
      " 25%|██▌       | 5/20 [00:35<01:47,  7.15s/it]\u001B[A\n",
      " 30%|███       | 6/20 [00:42<01:40,  7.15s/it]\u001B[A\n",
      " 35%|███▌      | 7/20 [00:50<01:32,  7.15s/it]\u001B[A\n",
      " 40%|████      | 8/20 [00:57<01:25,  7.15s/it]\u001B[A\n",
      " 45%|████▌     | 9/20 [01:04<01:18,  7.15s/it]\u001B[A\n",
      " 50%|█████     | 10/20 [01:11<01:11,  7.15s/it]\u001B[A\n",
      " 55%|█████▌    | 11/20 [01:18<01:04,  7.15s/it]\u001B[A\n",
      " 60%|██████    | 12/20 [01:25<00:57,  7.15s/it]\u001B[A\n",
      " 65%|██████▌   | 13/20 [01:33<00:50,  7.15s/it]\u001B[A\n",
      " 70%|███████   | 14/20 [01:40<00:42,  7.15s/it]\u001B[A\n",
      " 75%|███████▌  | 15/20 [01:47<00:35,  7.15s/it]\u001B[A\n",
      " 80%|████████  | 16/20 [01:54<00:28,  7.15s/it]\u001B[A\n",
      " 85%|████████▌ | 17/20 [02:01<00:21,  7.15s/it]\u001B[A\n",
      " 90%|█████████ | 18/20 [02:08<00:14,  7.15s/it]\u001B[A\n",
      " 95%|█████████▌| 19/20 [02:15<00:07,  7.15s/it]\u001B[A\n",
      "100%|██████████| 20/20 [02:23<00:00,  7.15s/it]\u001B[A\n",
      "200it [02:23,  1.39it/s]              \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Count--> 200\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import subprocess\n",
    "\n",
    "# Generate IDGEN samples (pickles) for FID scoring by invoking gen_final_data.py once per epoch tag.\n",
    "# NOTE(review): every iteration passes the SAME epoch-300 checkpoint below; only --save_dir depends on\n",
    "# `epoch`. If per-epoch samples are intended, point --diffuser_loc at the matching checkpoint.\n",
    "DIFFUSER_CKPT = 'final_model_NODROP/wEpoch300/ckpt_300_checkpoint.pt'\n",
    "GEN_EPOCHS = [0, 50, 100, 150, 200]\n",
    "\n",
    "for epoch in GEN_EPOCHS:\n",
    "    print('getting images for epoch', epoch)\n",
    "    cmd = [\n",
    "        'python3', 'gen_final_data.py',\n",
    "        '--pkl_loc=baseline_samples/do_X.pkl',\n",
    "        f'--diffuser_loc={DIFFUSER_CKPT}',\n",
    "        '--n_samples=20',\n",
    "        '--batch_size=200',\n",
    "        '--device=0',\n",
    "        f'--save_dir=FID_scores/IDGEN{epoch:003}_pkl',\n",
    "    ]\n",
    "    # check=True raises CalledProcessError on a non-zero exit instead of silently\n",
    "    # continuing (os.system's return code was previously discarded).\n",
    "    subprocess.run(cmd, check=True)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cur epoch 0\n",
      "Cur epoch 50\n",
      "Cur epoch 100\n",
      "Cur epoch 150\n",
      "Cur epoch 200\n",
      "Cur epoch 250\n",
      "Cur epoch 300\n"
     ]
    }
   ],
   "source": [
    "# Convert each epoch's generated pickle (final_W2XY.pkl) into individual PNGs for FID scoring.\n",
    "# NOTE(review): epochs 250/300 are read here but are not generated by the cell above;\n",
    "# their *_pkl folders must already exist from an earlier run.\n",
    "# NOTE(review): inputs are read relative to the working directory while images are written\n",
    "# under an absolute /root/... path; both refer to the same FID_scores folder only when the\n",
    "# notebook runs from /root/PycharmProjects/IDGEN/napkin_mnist.\n",
    "FID_IN_ROOT = './FID_scores'\n",
    "FID_OUT_ROOT = '/root/PycharmProjects/IDGEN/napkin_mnist/FID_scores'\n",
    "\n",
    "for epoch in [0, 50, 100, 150, 200, 250, 300]:\n",
    "    print('Cur epoch', epoch)\n",
    "\n",
    "    file = f'{FID_IN_ROOT}/IDGEN{epoch:003}_pkl/final_W2XY.pkl'\n",
    "    with open(file, 'rb') as f:\n",
    "        ydox = pickle.load(f)\n",
    "\n",
    "    # reshape (unlike view) also works if the stored tensor is non-contiguous.\n",
    "    cur_images = ydox['Y'].reshape(-1, 3, 32, 32)\n",
    "\n",
    "    img_save_path = f'{FID_OUT_ROOT}/IDGEN{epoch:003}'\n",
    "    os.makedirs(img_save_path, exist_ok=True)\n",
    "\n",
    "    # img_idx instead of `iter` to avoid shadowing the builtin.\n",
    "    for img_idx, img in enumerate(cur_images):\n",
    "        save_image(img, f'{img_save_path}/img{img_idx}.png')\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [
    {
     "data": {
      "text/plain": "torch.Size([4000, 3, 32, 32])"
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cur_images.shape"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
