{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import torch\n",
    "import argparse\n",
    "import itertools\n",
    "import numpy as np\n",
    "# from unet import Unet\n",
    "from tqdm import tqdm\n",
    "import torch.optim as optim\n",
    "from cfg.diffusion import GaussianDiffusion\n",
    "from torchvision.utils import save_image\n",
    "from cfg.utils import get_named_beta_schedule\n",
    "from cfg.embedding import ConditionalEmbedding, MNISTEmbedding\n",
    "from cfg.Scheduler import GradualWarmupScheduler\n",
    "\n",
    "import sys; sys.path.append('../retrain_trick'); sys.path.append('../Morpho-MNIST')\n",
    "#from dataloader_cifar import load_data, transback\n",
    "# from gen_retrain_trick import load_data, transback, RetrainTrickDataset\n",
    "from cfg.dataloader_pickle import PickleDataset, transback, load_data\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size\n",
    "import torchvision\n",
    "from path_constant import project_root\n",
    "import torch as t\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # prefer the GPU, else fall back to CPU\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "# Hyperparameters for the per-variable diffusion models.\n",
    "\n",
    "def str2bool(value) -> bool:\n",
    "    # argparse's type=bool is a trap: bool('False') is True. Accept true/1/yes (any case) as True.\n",
    "    return str(value).lower() in ['true', '1', 'yes']\n",
    "\n",
    "parser = argparse.ArgumentParser(description='test for diffusion model')\n",
    "parser.add_argument('--train_pkl', type=str, default=f\"{project_root}/napkin_mnist/base_data/napkin_mnist_train.pkl\")\n",
    "# NOTE(review): the default validation pickle is the *training* pickle; pass --val_pkl for held-out data.\n",
    "parser.add_argument('--val_pkl', type=str, default=f\"{project_root}/napkin_mnist/base_data/napkin_mnist_train.pkl\")\n",
    "parser.add_argument('--datakey', type=str, help='which of the data keys is the one we want to generate')\n",
    "parser.add_argument('--condkey', type=str, help='which of the data keys is the one we use for conditioning')\n",
    "parser.add_argument('--batchsize',type=int,default=256,help='batch size per device for training Unet model')\n",
    "parser.add_argument('--numworkers',type=int,default=4,help='num workers for training Unet model')\n",
    "parser.add_argument('--inch',type=int,default=3,help='input channels for Unet model')\n",
    "parser.add_argument('--modch',type=int,default=64,help='model channels for Unet model')\n",
    "parser.add_argument('--T',type=int,default=1000,help='timesteps for Unet model')\n",
    "parser.add_argument('--outch',type=int,default=3,help='output channels for Unet model')\n",
    "# nargs='+' parses '--chmul 1 2 2 2' into [1, 2, 2, 2]; type=list would split a string into characters.\n",
    "parser.add_argument('--chmul',type=int,nargs='+',default=[1,2,2,2],help='architecture parameters training Unet model')\n",
    "parser.add_argument('--numres',type=int,default=2,help='number of resblocks for each block in Unet model')\n",
    "parser.add_argument('--cdim',type=int,default=64,help='dimension of conditional embedding')\n",
    "parser.add_argument('--useconv',type=str2bool,default=True,help='whether use convlution in downsample')\n",
    "parser.add_argument('--droprate',type=float,default=0.1,help='dropout rate for model')\n",
    "# On the command line, give a torch dtype name, e.g. --dtype float16.\n",
    "parser.add_argument('--dtype',type=lambda name: getattr(torch, name),default=torch.float32)\n",
    "parser.add_argument('--lr',type=float,default=2e-4,help='learning rate')\n",
    "parser.add_argument('--w',type=float,default=1.8,help='hyperparameters for classifier-free guidance strength')\n",
    "parser.add_argument('--v',type=float,default=0.3,help='hyperparameters for the variance of posterior distribution')\n",
    "parser.add_argument('--epoch',type=int,default=50,help='epochs for training')\n",
    "parser.add_argument('--multiplier',type=float,default=2.5,help='multiplier for warmup')\n",
    "parser.add_argument('--threshold',type=float,default=0.1,help='threshold for classifier-free guidance')\n",
    "parser.add_argument('--interval',type=int,default=9,help='epoch interval between two evaluations')\n",
    "parser.add_argument('--moddir',type=str,default=f'{project_root}/Baselines/DiffusionBasedCausalModels/imgcond_model',help='model addresses')\n",
    "parser.add_argument('--samdir',type=str,default=f'{project_root}/Baselines/DiffusionBasedCausalModels/imgcond_sample',help='sample addresses')\n",
    "parser.add_argument('--genbatch',type=int,default=80,help='batch size for sampling process')\n",
    "parser.add_argument('--clsnum',type=int,default=10,help='num of label classes')\n",
    "parser.add_argument('--num_steps',type=int,default=50,help='sampling steps for DDIM')\n",
    "parser.add_argument('--eta',type=float,default=0,help='eta for variance during DDIM sampling process')\n",
    "parser.add_argument('--select',type=str,default='linear',help='selection stragies for DDIM')\n",
    "parser.add_argument('--ddim',type=str2bool,default=True,help='whether to use ddim')\n",
    "parser.add_argument('--local_rank',default=-1,type=int,help='node rank for distributed training')\n",
    "\n",
    "# parse_known_args collects unrecognised arguments (e.g. those the Jupyter kernel\n",
    "# was started with) in `unknown` instead of exiting with an error.\n",
    "params, unknown = parser.parse_known_args()\n",
    "\n",
    "params\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "def cycler(loader):\n",
    "    \"\"\"Yield batches from `loader` endlessly, starting a new pass whenever one ends.\n",
    "\n",
    "    Raises ValueError if a full pass yields nothing (e.g. drop_last=True with fewer\n",
    "    samples than one batch), instead of looping forever without yielding.\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        yielded = False\n",
    "        for batch in loader:\n",
    "            yielded = True\n",
    "            yield batch\n",
    "        if not yielded:\n",
    "            raise ValueError('cycler: loader produced no batches')\n",
    "\n",
    "def load_data(dataset: PickleDataset, batchsize: int) -> DataLoader:\n",
    "    \"\"\"Wrap `dataset` in a shuffling DataLoader that drops the last incomplete batch.\n",
    "\n",
    "    Shadows the load_data imported from cfg.dataloader_pickle. This version uses no\n",
    "    DistributedSampler and returns only the DataLoader.\n",
    "    \"\"\"\n",
    "    return DataLoader(dataset, batch_size=batchsize, shuffle=True, drop_last=True)\n",
    "\n",
    "# load data\n",
    "train_data = PickleDataset(params.train_pkl)\n",
    "val_data = PickleDataset(params.val_pkl)\n",
    "dataloader = load_data(train_data, params.batchsize)\n",
    "# genbatch is split across the visible GPUs. On a CPU-only machine device_count() is 0,\n",
    "# so clamp both the divisor and the per-device batch size to at least 1.\n",
    "n_devices = max(torch.cuda.device_count(), 1)\n",
    "val_loader = load_data(val_data, max(params.genbatch // n_devices, 1))\n",
    "val_cycler = cycler(val_loader)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "# Causal graph:\n",
    "#   W1 -> W2a\n",
    "#   W1 -> W2b\n",
    "#   W2a -> X <- W2b\n",
    "#   X -> Y\n",
    "\n",
    "def get_parent_embedding(datakey, cemblayer, batch, training=True):\n",
    "    \"\"\"Embed the parents of `datakey`, read from the dict `batch`, with `cemblayer`.\n",
    "\n",
    "    X is conditioned on the labels batch['W2a'] and batch['W2b']; Y on the image batch['X'].\n",
    "    Any other key (e.g. the root W1) has no parents and returns None.\n",
    "\n",
    "    When `training` is True, F.dropout1d with p=params.threshold is applied to the\n",
    "    embedding (presumably the condition dropout for classifier-free guidance).\n",
    "    F.dropout1d does not follow module eval() mode, so samplers must pass training=False.\n",
    "    \"\"\"\n",
    "    if datakey == 'X':\n",
    "        cemb = cemblayer(batch['W2a'].to(device), batch['W2b'].to(device))\n",
    "    elif datakey == 'Y':\n",
    "        cemb = cemblayer(batch['X'].to(device))\n",
    "    else:\n",
    "        return None\n",
    "    return F.dropout1d(cemb, params.threshold, training=training)\n",
    "\n",
    "\n",
    "def sample(datakey, diffusion, cemblayer, batch_size, parent_batch):\n",
    "    \"\"\"Generate `batch_size` images for `datakey`, conditioned on `parent_batch`.\n",
    "\n",
    "    Puts diffusion.model (and `cemblayer`, if given) in eval mode and embeds the parents\n",
    "    without condition dropout. It then runs DDIM or ancestral sampling according to\n",
    "    params.ddim and maps the result through transback.\n",
    "    Returns the images reshaped to (-1, 3, 32, 32).\n",
    "    \"\"\"\n",
    "    diffusion.model.eval()\n",
    "    if cemblayer is not None:\n",
    "        cemblayer.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # No condition dropout at inference: every sample keeps its parents' embedding.\n",
    "        cemb = get_parent_embedding(datakey, cemblayer, parent_batch, training=False)\n",
    "\n",
    "        genshape = (batch_size, 3, 32, 32)\n",
    "        if params.ddim:\n",
    "            generated = diffusion.ddim_sample(genshape, params.num_steps, params.eta, params.select, cemb=cemb)\n",
    "        else:\n",
    "            generated = diffusion.sample(genshape, cemb=cemb)\n",
    "\n",
    "        final_imgs = transback(generated).reshape(-1, 3, 32, 32).contiguous()\n",
    "\n",
    "    return final_imgs\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Classifiers W1 -> W2a and W1 -> W2b (architecture only; trained weights are loaded below)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),])\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self, outdim):\n",
    "        super(Net,self).__init__()\n",
    "        self.linear1 = nn.Linear(3*32*32, 100)\n",
    "        self.linear2 = nn.Linear(100, 50)\n",
    "        self.final = nn.Linear(50, outdim)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, img): #convert + flatten\n",
    "        x = img.view(-1, 3*32*32)\n",
    "        x = self.relu(self.linear1(x))\n",
    "        x = self.relu(self.linear2(x))\n",
    "        x = self.final(x)\n",
    "        return x\n",
    "net1 = Net(outdim=10).to(device)\n",
    "net2 = Net(outdim=6).to(device)\n",
    "\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "from cfg.unet import Unet\n",
    "from cfg.embedding import JointEmbedding2, JointConditionalEmbedding\n",
    "\n",
    "\n",
    "# initialize models: one U-Net per generated variable.\n",
    "net = {}\n",
    "for lb in ['W1', 'X', 'Y']:\n",
    "    # W1 is the root of the graph (no parents), so its U-Net gets use_cemb=False.\n",
    "    use_cemb = lb != 'W1'\n",
    "    net[lb] = Unet(\n",
    "        in_ch=params.inch,    # 3 for W1, X and Y\n",
    "        mod_ch=params.modch,\n",
    "        out_ch=params.outch,  # 3 for W1, X and Y\n",
    "        ch_mul=params.chmul,\n",
    "        num_res_blocks=params.numres,\n",
    "        cdim=params.cdim,\n",
    "        use_conv=params.useconv,\n",
    "        droprate=params.droprate,\n",
    "        dtype=params.dtype,\n",
    "        use_cemb=use_cemb,\n",
    "    )\n",
    "\n",
    "# Per-variable conditional embedding layers, populated in the following cells.\n",
    "cemblayer = {}\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# W1 is the root node: nothing to embed.\n",
    "cemblayer['W1'] = None\n",
    "\n",
    "# X is conditioned jointly on the two labels W2a and W2b.\n",
    "cemblayer['X'] = JointConditionalEmbedding(\n",
    "    num_labels_0=10,  # presumably the W2a label count\n",
    "    num_labels_1=6,   # presumably the W2b label count\n",
    "    d_model=params.cdim,\n",
    "    dim=params.cdim,\n",
    ").to(device)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Y is conditioned on the image X (the 3 and hw=32 presumably give X's channels and size; confirm).\n",
    "cemblayer.update(Y=MNISTEmbedding(3, params.cdim, hw=32).to(device))\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Model Load"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "def mode_load(net, cemblayer, params, moddir, epoch_fixed=None):\n",
    "    \"\"\"Restore a trained U-Net (and its conditional embedding) and wrap it in a GaussianDiffusion.\n",
    "\n",
    "    The epoch comes from `last_epoch.pt` in `moddir`, or from `epoch_fixed` when given.\n",
    "    `ckpt_{epoch}_checkpoint.pt` is then loaded on the CPU. Its 'net' entry goes into `net`\n",
    "    and, if `cemblayer` is not None, its 'cemblayer' entry goes into `cemblayer`.\n",
    "\n",
    "    Returns:\n",
    "        (diffusion, cemblayer)\n",
    "    Raises:\n",
    "        FileNotFoundError: if `moddir` has no `last_epoch.pt`. The caller unpacks the\n",
    "        result, so returning None would only fail later with a less clear error.\n",
    "    \"\"\"\n",
    "    lastpath = os.path.join(moddir, 'last_epoch.pt')\n",
    "    if not os.path.exists(lastpath):\n",
    "        raise FileNotFoundError(f'No saved models: {lastpath} does not exist')\n",
    "\n",
    "    if epoch_fixed is not None:\n",
    "        lastepc = epoch_fixed\n",
    "    else:\n",
    "        # map_location lets GPU-saved files load on CPU-only machines too.\n",
    "        lastepc = torch.load(lastpath, map_location='cpu')['last_epoch']\n",
    "\n",
    "    checkpoint = torch.load(os.path.join(moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')\n",
    "    net.load_state_dict(checkpoint['net'])\n",
    "    if cemblayer is not None:\n",
    "        cemblayer.load_state_dict(checkpoint['cemblayer'])\n",
    "\n",
    "    betas = get_named_beta_schedule(num_diffusion_timesteps = params.T)\n",
    "    diffusion = GaussianDiffusion(\n",
    "                    dtype = params.dtype,\n",
    "                    model = net,\n",
    "                    betas = betas,\n",
    "                    w = params.w,\n",
    "                    v = params.v,\n",
    "                    device = device\n",
    "                )\n",
    "    return diffusion, cemblayer\n",
    "\n",
    "\n",
    "diffusions, cemblayers = {}, {}\n",
    "# The Inference section samples W1, then X, then Y, so all three models must be loaded.\n",
    "for lb in ['W1', 'X', 'Y']:\n",
    "    moddir = os.path.join(params.moddir, lb)\n",
    "    samdir = os.path.join(params.samdir, lb)\n",
    "    params.datakey = lb\n",
    "    print(moddir, samdir)\n",
    "    # NOTE(review): the same fixed epoch is used for every variable; confirm each has a ckpt_240 file.\n",
    "    diffusions[lb], cemblayers[lb] = mode_load(net[lb], cemblayer[lb], params, moddir, epoch_fixed=240)\n",
    "    diffusions[lb].eval()\n",
    "    # W1 has no embedding layer (None).\n",
    "    if cemblayers[lb] is not None:\n",
    "        cemblayers[lb].eval()\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "# Classifiers predicting W2a (10 logits) and W2b (6 logits) from the image W1.\n",
    "net1 = Net(outdim=10).to(device)\n",
    "net2 = Net(outdim=6).to(device)\n",
    "\n",
    "moddir = os.path.join(params.moddir, 'W2')\n",
    "lastpath = os.path.join(moddir, 'last_epoch.pt')\n",
    "if not os.path.exists(lastpath):\n",
    "    # Fail loudly: continuing would predict W2a/W2b with untrained, randomly initialised nets.\n",
    "    raise FileNotFoundError(f'No saved W2 classifiers: {lastpath} does not exist')\n",
    "\n",
    "# map_location='cpu' lets GPU-saved checkpoints load anywhere; load_state_dict copies into the on-device params.\n",
    "lastepc = torch.load(lastpath, map_location='cpu')['last_epoch']\n",
    "checkpoint = torch.load(os.path.join(moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')\n",
    "\n",
    "net1.load_state_dict(checkpoint['netW2a'])\n",
    "net2.load_state_dict(checkpoint['netW2b'])\n",
    "\n",
    "net1.eval()\n",
    "net2.eval()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Inference"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "N = 100  # number of joint samples (W1, W2a, W2b, X, Y) drawn below\n",
    "print('Generating W1')\n",
    "# W1 is the root node: it has no embedding layer and no parent values.\n",
    "W1 = sample('W1', diffusions['W1'], cemblayer=None, batch_size=N, parent_batch={})\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "print('Generating W2')\n",
    "# W2a and W2b are the argmax labels of the W1 -> W2 classifiers. This is inference only,\n",
    "# so no_grad avoids building an autograd graph through net1/net2.\n",
    "with torch.no_grad():\n",
    "    # Net.forward flattens its input itself, so the (N, 3, 32, 32) images go in directly.\n",
    "    W2a = t.argmax(net1(W1), dim=1)\n",
    "    W2b = t.argmax(net2(W1), dim=1)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "print('Generating X')\n",
    "# X's parents are the two labels predicted for W2a and W2b.\n",
    "X = sample('X', diffusions['X'], cemblayers['X'], N, parent_batch={'W2a': W2a, 'W2b': W2b})\n",
    "\n",
    "print('Generating Y')\n",
    "# Y's only parent is the generated image X.\n",
    "Y = sample('Y', diffusions['Y'], cemblayers['Y'], N, parent_batch={'X': X})"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def show(id, W1, W2a, W2b, X, Y):\n",
    "    \"\"\"Display entry `id` of each variable: the images W1, X, Y and the labels W2a, W2b.\n",
    "\n",
    "    Each image is permuted with (1, 2, 0) to channels-last for imshow; no rescaling is applied.\n",
    "    \"\"\"\n",
    "    print('No processing. Just printing the values.')\n",
    "    plt.imshow(W1[id].cpu().permute(1,2,0))\n",
    "    plt.show()\n",
    "    print(W2a[id], W2b[id])\n",
    "    plt.imshow(X[id].cpu().permute(1,2,0))\n",
    "    plt.show()\n",
    "    plt.imshow(Y[id].cpu().permute(1,2,0))\n",
    "    plt.show()\n",
    "\n",
    "# One real training batch, for comparing generated samples with data.\n",
    "batch = next(iter(dataloader))\n",
    "\n",
    "# Random index into the N generated samples; named so the builtin id() is not shadowed.\n",
    "sample_id = torch.randint(0, N, (1,)).item()\n",
    "\n",
    "# Example: show(sample_id, W1, W2a, W2b, X, Y), or for real data:\n",
    "# show(41, batch['W1'], batch['W2a'], batch['W2b'], batch['X'], batch['Y'])\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "# Grids of generated W1, X and Y, placed side by side (concatenated along the width axis).\n",
    "grids = [torchvision.utils.make_grid(images, normalize=True) for images in (W1, X, Y)]\n",
    "fake_W1, fake_X, fake_Y = grids\n",
    "fake = torch.cat(grids, dim=2)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(20, 8))\n",
    "ax.imshow(fake.cpu().permute(1, 2, 0))\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from napkin_mnist.train_classifiers import load_classifiers\n",
    "\n",
    "save_dir = f'{project_root}/napkin_mnist/saved_classifier_models/'\n",
    "save_name = 'napkin_classifer'  # presumably must match the saved files' name; keep the spelling\n",
    "true_classifiers = load_classifiers(save_dir, save_name)\n",
    "\n",
    "# Only the Y classifiers are used below. Move them to the device and switch them to eval mode\n",
    "# so layers such as dropout or batch-norm (if present) run in inference mode.\n",
    "for clf_key in ('Y_color', 'Y_digit'):\n",
    "    true_classifiers[clf_key] = true_classifiers[clf_key].to(device).eval()\n",
    "\n",
    "true_classifiers.keys()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# P(Y|do(X=3))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Observational reference: up to 500 real training samples with W2a == 3 and W2b == 1.\n",
    "train_all = train_data[:]  # materialise the dataset once instead of once per lookup\n",
    "truidx = torch.where((train_all['W2b'] == 1) & (train_all['W2a'] == 3))[0][:500]\n",
    "truX = train_all['X'][truidx].to(device)\n",
    "truY = train_all['Y'][truidx].to(device)\n",
    "\n",
    "# minlength keeps one bin per class even if some class is never predicted.\n",
    "with torch.no_grad():\n",
    "    output = true_classifiers['Y_color'](truY)\n",
    "    bins = torch.bincount(t.argmax(output, dim=1), minlength=output.shape[1])\n",
    "    print('True color P(Y|X) from classifier', bins / bins.sum())\n",
    "\n",
    "    output = true_classifiers['Y_digit'](truY)\n",
    "    bins = torch.bincount(t.argmax(output, dim=1), minlength=output.shape[1])\n",
    "    print('True digit P(Y|X) from classifier', bins / bins.sum())\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Generate one Y for each selected real X image.\n",
    "intv = truX\n",
    "doY = sample('Y', diffusions['Y'], cemblayers['Y'], batch_size=len(intv), parent_batch={'X': intv})"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Real X images (left) next to the Y images generated from them (right).\n",
    "fakeX = torchvision.utils.make_grid(intv, normalize=True)\n",
    "fakeY = torchvision.utils.make_grid(doY, normalize=True)\n",
    "\n",
    "fake = torch.cat([fakeX, fakeY], dim=2)\n",
    "fig, ax = plt.subplots(figsize=(30, 30))\n",
    "ax.imshow(fake.cpu().permute(1, 2, 0))\n",
    "plt.show()\n",
    "\n",
    "# Class distribution of the generated Y; minlength keeps one bin per class.\n",
    "with torch.no_grad():\n",
    "    output = true_classifiers['Y_color'](doY)\n",
    "    bins = torch.bincount(t.argmax(output, dim=1), minlength=output.shape[1])\n",
    "    print('Pred color P(Y|X) from classifier', bins / bins.sum())\n",
    "\n",
    "    output = true_classifiers['Y_digit'](doY)\n",
    "    bins = torch.bincount(t.argmax(output, dim=1), minlength=output.shape[1])\n",
    "    print('Pred digit P(Y|X) from classifier', bins / bins.sum())"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Load the common do(X) for all baselines"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "file = f'{project_root}/Baselines/Compare/baseline_samples/do_X.pkl'\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load this project-generated file.\n",
    "with open(file, 'rb') as f:\n",
    "    do_x = pickle.load(f)\n",
    "\n",
    "# The first 1500 X images form do3 and the next 1500 form do5\n",
    "# (presumably red 3s and red 5s; confirm against how do_X.pkl was built).\n",
    "N_PER_INTERVENTION = 1500\n",
    "do3 = do_x['X'][:N_PER_INTERVENTION]\n",
    "do5 = do_x['X'][N_PER_INTERVENTION:2 * N_PER_INTERVENTION]\n",
    "\n",
    "# Each set appears twice, so every intervened X image gets two separately sampled Ys.\n",
    "intv = torch.cat([do3, do3, do5, do5])\n",
    "\n",
    "print('Generating Y|do(X)')\n",
    "\n",
    "intvloader = DataLoader(intv, batch_size=100)\n",
    "result = []\n",
    "for batch in tqdm(intvloader, desc='Y|do(X)'):\n",
    "    doY = sample('Y', diffusions['Y'], cemblayers['Y'], batch.shape[0], {'X': batch})\n",
    "    result.append(doY)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Grid of the last generated batch of Y|do(X) samples.\n",
    "fakeY = torchvision.utils.make_grid(doY, normalize=True)\n",
    "fig, ax = plt.subplots(figsize=(20, 30))\n",
    "ax.imshow(fakeY.cpu().permute(1, 2, 0))\n",
    "plt.show()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Stack the per-batch Y samples into one tensor and save them next to do_X.pkl.\n",
    "final_YdoX = torch.cat(result)\n",
    "\n",
    "# NOTE(review): 'X' is the full do_x['X'], but 'Y' was generated from intv = [do3, do3, do5, do5].\n",
    "# Rows of X and Y only line up if the consumer accounts for that ordering; confirm.\n",
    "final_result = dict(X=do_x['X'], Y=final_YdoX)\n",
    "\n",
    "file = f'{project_root}/Baselines/Compare/baseline_samples/diffscmYdoX.pkl'\n",
    "with open(file, 'wb') as handle:\n",
    "    pickle.dump(final_result, handle, protocol=pickle.HIGHEST_PROTOCOL)\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Class distribution of all saved Y|do(X) samples, as judged by the pretrained classifiers.\n",
    "# minlength keeps one bin per class even if a class is never predicted.\n",
    "with torch.no_grad():\n",
    "    output = true_classifiers['Y_color'](final_YdoX)\n",
    "    bins = torch.bincount(t.argmax(output, dim=1), minlength=output.shape[1])\n",
    "    print('Pred color P(Y|X) from classifier', bins / bins.sum())\n",
    "\n",
    "    output = true_classifiers['Y_digit'](final_YdoX)\n",
    "    bins = torch.bincount(t.argmax(output, dim=1), minlength=output.shape[1])\n",
    "    print('Pred digit P(Y|X) from classifier', bins / bins.sum())"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Since the interventions are a red 3 and a red 5, the predicted colors shift towards classes 0, 1 and 2,\n",
    "and (in a previous run) the predicted digit is 3 in 45.57% of samples and 5 in 46.07%.\n",
    "So the model correctly reproduces the conditional distribution $P(Y \\mid X)$.\n",
    "However, this is not the interventional distribution $P(Y \\mid do(X))$: this model cannot generate\n",
    "interventional samples when the graph contains a confounder."
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
