{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import torch\n",
    "import argparse\n",
    "import itertools\n",
    "import numpy as np\n",
    "# from unet import Unet\n",
    "from tqdm import tqdm\n",
    "import torch.optim as optim\n",
    "from cfg.diffusion import GaussianDiffusion\n",
    "from torchvision.utils import save_image\n",
    "from cfg.utils import get_named_beta_schedule\n",
    "from cfg.embedding import ConditionalEmbedding, MNISTEmbedding\n",
    "from cfg.Scheduler import GradualWarmupScheduler\n",
    "\n",
    "import sys; sys.path.append('../retrain_trick'); sys.path.append('../Morpho-MNIST')\n",
    "#from dataloader_cifar import load_data, transback\n",
    "# from gen_retrain_trick import load_data, transback, RetrainTrickDataset\n",
    "from cfg.dataloader_pickle import PickleDataset, transback, load_data\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "data": {
      "text/plain": "Namespace(train_pkl='/root/PycharmProjects/IDGEN/napkin_mnist/base_data/napkin_mnist_train.pkl', val_pkl='/root/PycharmProjects/IDGEN/napkin_mnist/base_data/napkin_mnist_train.pkl', datakey=None, condkey=None, batchsize=256, numworkers=4, inch=3, modch=64, T=1000, outch=3, chmul=[1, 2, 2, 2], numres=2, cdim=64, useconv=True, droprate=0.1, dtype=torch.float32, lr=0.0002, w=1.8, v=0.3, epoch=10, multiplier=2.5, threshold=0.1, interval=9, moddir='imgcond_model', samdir='imgcond_sample', genbatch=80, clsnum=10, num_steps=50, eta=0, select='linear', ddim=False, local_rank=-1)"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from path_constant import project_root\n",
    "\n",
    "# several hyperparameters for model\n",
    "def _str2bool(value):\n",
    "    \"\"\"Parse a flag value: 'true', '1' or 'yes' (case-insensitive) give True, anything else False.\"\"\"\n",
    "    return str(value).lower() in ['true', '1', 'yes']\n",
    "\n",
    "\n",
    "parser = argparse.ArgumentParser(description='test for diffusion model')\n",
    "parser.add_argument('--train_pkl', type=str, default=f\"{project_root}/napkin_mnist/base_data/napkin_mnist_train.pkl\")\n",
    "# NOTE(review): the validation default points at the *train* pickle - confirm whether a separate validation file exists.\n",
    "parser.add_argument('--val_pkl', type=str, default=f\"{project_root}/napkin_mnist/base_data/napkin_mnist_train.pkl\")\n",
    "parser.add_argument('--datakey', type=str, help='which of the data keys is the one we want to generate')\n",
    "parser.add_argument('--condkey', type=str, help='which of the data keys is the one we use for conditioning')\n",
    "parser.add_argument('--batchsize',type=int,default=256,help='batch size per device for training Unet model')\n",
    "parser.add_argument('--numworkers',type=int,default=4,help='num workers for training Unet model')\n",
    "parser.add_argument('--inch',type=int,default=3,help='input channels for Unet model')\n",
    "parser.add_argument('--modch',type=int,default=64,help='model channels for Unet model')\n",
    "parser.add_argument('--T',type=int,default=1000,help='timesteps for Unet model')\n",
    "parser.add_argument('--outch',type=int,default=3,help='output channels for Unet model')\n",
    "# type=list would split a command-line string into single characters; read one int per token instead (--chmul 1 2 2 2)\n",
    "parser.add_argument('--chmul',type=int,nargs='+',default=[1,2,2,2],help='architecture parameters training Unet model')\n",
    "parser.add_argument('--numres',type=int,default=2,help='number of resblocks for each block in Unet model')\n",
    "parser.add_argument('--cdim',type=int,default=64,help='dimension of conditional embedding')\n",
    "# type=bool turns every non-empty string (including 'False') into True; parse the text explicitly\n",
    "parser.add_argument('--useconv',type=_str2bool,default=True,help='whether use convlution in downsample')\n",
    "parser.add_argument('--droprate',type=float,default=0.1,help='dropout rate for model')\n",
    "parser.add_argument('--dtype',default=torch.float32)\n",
    "parser.add_argument('--lr',type=float,default=2e-4,help='learning rate')\n",
    "parser.add_argument('--w',type=float,default=1.8,help='hyperparameters for classifier-free guidance strength')\n",
    "parser.add_argument('--v',type=float,default=0.3,help='hyperparameters for the variance of posterior distribution')\n",
    "parser.add_argument('--epoch',type=int,default=10,help='epochs for training')\n",
    "parser.add_argument('--multiplier',type=float,default=2.5,help='multiplier for warmup')\n",
    "parser.add_argument('--threshold',type=float,default=0.1,help='threshold for classifier-free guidance')\n",
    "parser.add_argument('--interval',type=int,default=9,help='epoch interval between two evaluations')\n",
    "parser.add_argument('--moddir',type=str,default='imgcond_model',help='model addresses')\n",
    "parser.add_argument('--samdir',type=str,default='imgcond_sample',help='sample addresses')\n",
    "parser.add_argument('--genbatch',type=int,default=80,help='batch size for sampling process')\n",
    "parser.add_argument('--clsnum',type=int,default=10,help='num of label classes')\n",
    "parser.add_argument('--num_steps',type=int,default=50,help='sampling steps for DDIM')\n",
    "parser.add_argument('--eta',type=float,default=0,help='eta for variance during DDIM sampling process')\n",
    "parser.add_argument('--select',type=str,default='linear',help='selection stragies for DDIM')\n",
    "parser.add_argument('--ddim',type=_str2bool,default=False,help='whether to use ddim')\n",
    "parser.add_argument('--local_rank',default=-1,type=int,help='node rank for distributed training')\n",
    "\n",
    "# parse_known_args (instead of parse_args) returns unrecognised arguments in `unknown` rather than\n",
    "# exiting, so extra argv entries (e.g. those a notebook kernel is started with) do not abort the cell\n",
    "params, unknown = parser.parse_known_args()\n",
    "\n",
    "params\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "def load_data(dataset: PickleDataset, batchsize: int) -> DataLoader:\n",
    "    \"\"\"Wrap ``dataset`` in a shuffling DataLoader.\n",
    "\n",
    "    NOTE: this definition shadows the ``load_data`` imported from ``cfg.dataloader_pickle``.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset to draw batches from.\n",
    "        batchsize: number of samples per batch.\n",
    "\n",
    "    Returns:\n",
    "        A single DataLoader (no sampler is created or returned). Batches are reshuffled on\n",
    "        every pass and a trailing incomplete batch is dropped.\n",
    "    \"\"\"\n",
    "    return DataLoader(dataset,\n",
    "                      batch_size=batchsize,\n",
    "                      shuffle=True,\n",
    "                      drop_last=True)\n",
    "\n",
    "# load data\n",
    "train_data = PickleDataset(params.train_pkl)\n",
    "val_data = PickleDataset(params.val_pkl)\n",
    "dataloader = load_data(train_data, params.batchsize)\n",
    "# torch.cuda.device_count() is 0 on a CPU-only machine; clamp to 1 so the CPU fallback\n",
    "# selected for `device` above does not end in a ZeroDivisionError here\n",
    "val_loader = load_data(val_data, params.genbatch // max(torch.cuda.device_count(), 1))"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "def cycler(loader):\n",
    "    \"\"\"Yield items from ``loader`` forever, starting a fresh pass each time it is exhausted.\"\"\"\n",
    "    while True:\n",
    "        yield from loader\n",
    "\n",
    "\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def get_parent_embedding(datakey, cemblayer, batch):\n",
    "    \"\"\"Build the conditioning embedding for ``datakey`` from its parent variables in ``batch``.\n",
    "\n",
    "    Dependency structure used in this notebook:\n",
    "        W1 -> W2a, W1 -> W2b\n",
    "        W2a -> X <- W2b\n",
    "        X -> Y\n",
    "\n",
    "    Args:\n",
    "        datakey: variable being modelled. 'X' is embedded from batch['W2a'] and batch['W2b'],\n",
    "            'Y' from batch['X']; any other key has no parents here.\n",
    "        cemblayer: embedding module applied to the parents (unused when there are none).\n",
    "        batch: mapping from variable name to tensor.\n",
    "\n",
    "    Returns:\n",
    "        The parent embedding, or None when ``datakey`` is neither 'X' nor 'Y'.\n",
    "\n",
    "    Reads the notebook globals ``device`` and ``params.threshold``.\n",
    "    \"\"\"\n",
    "    if datakey == 'X':\n",
    "        cemb = cemblayer(batch['W2a'].to(device), batch['W2b'].to(device))\n",
    "    elif datakey == 'Y':\n",
    "        cemb = cemblayer(batch['X'].to(device))\n",
    "    else:\n",
    "        return None\n",
    "\n",
    "    # F.dropout1d defaults to training=True, which kept dropping (and rescaling) the conditioning\n",
    "    # even while sampling. Tie it to the layer's mode so that sample(), which puts the layer in\n",
    "    # eval mode, conditions on the untouched embedding.\n",
    "    return F.dropout1d(cemb, params.threshold, training=cemblayer.training)\n",
    "\n",
    "\n",
    "def sample(datakey, diffusion, cemblayer, batch_size, parent_batch):\n",
    "    \"\"\"Generate ``batch_size`` images for ``datakey`` conditioned on ``parent_batch``.\n",
    "\n",
    "    The diffusion model (and ``cemblayer`` when given) is switched to eval mode and left there.\n",
    "    ``params.ddim`` selects DDIM sampling; otherwise ``diffusion.sample`` is used.\n",
    "\n",
    "    Args:\n",
    "        datakey: variable to generate; forwarded to get_parent_embedding.\n",
    "        diffusion: object exposing ``model``, ``sample`` and ``ddim_sample``.\n",
    "        cemblayer: conditioning-embedding module, or None for an unconditioned variable.\n",
    "        batch_size: number of images to draw.\n",
    "        parent_batch: batch providing the parent values to condition on.\n",
    "\n",
    "    Returns:\n",
    "        The ``transback``-processed samples reshaped to (-1, 3, 32, 32).\n",
    "    \"\"\"\n",
    "    diffusion.model.eval()\n",
    "    if cemblayer is not None:\n",
    "        cemblayer.eval()\n",
    "\n",
    "    genshape = (batch_size, 3, 32, 32)\n",
    "    with torch.no_grad():\n",
    "        cemb = get_parent_embedding(datakey, cemblayer, parent_batch)\n",
    "\n",
    "        if params.ddim:\n",
    "            generated = diffusion.ddim_sample(genshape, params.num_steps, params.eta, params.select, cemb=cemb)\n",
    "        else:\n",
    "            generated = diffusion.sample(genshape, cemb=cemb)\n",
    "\n",
    "        img = transback(generated)\n",
    "        final_imgs = img.reshape(-1, 3, 32, 32).contiguous()\n",
    "\n",
    "    return final_imgs\n",
    "\n",
    "\n",
    "def train(net, cemblayer, params, moddir, samdir, max_iters_per_epoch=20):\n",
    "    \"\"\"Train (or resume training) one diffusion model, sampling and checkpointing periodically.\n",
    "\n",
    "    Args:\n",
    "        net: denoising network wrapped by GaussianDiffusion.\n",
    "        cemblayer: conditioning-embedding module optimised jointly with ``net``, or None.\n",
    "        params: hyperparameter namespace; ``params.datakey`` selects the batch entry to model.\n",
    "        moddir: directory holding ``last_epoch.pt`` and ``ckpt_<epoch>_checkpoint.pt``.\n",
    "        samdir: directory receiving ``generated_<epoch>_pict.png`` sample grids.\n",
    "        max_iters_per_epoch: stop each epoch after this many optimiser steps. The default of 20\n",
    "            is the cap that used to be hard-coded; pass None to iterate the whole dataloader.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(diffusion, cemblayer)``.\n",
    "\n",
    "    Reads the notebook globals ``dataloader``, ``val_cycler`` and ``device``.\n",
    "    \"\"\"\n",
    "    # resume from the last recorded epoch when a marker file exists\n",
    "    lastpath = os.path.join(moddir, 'last_epoch.pt')\n",
    "    if os.path.exists(lastpath):\n",
    "        lastepc = torch.load(lastpath)['last_epoch']\n",
    "        # load checkpoints\n",
    "        checkpoint = torch.load(os.path.join(moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')\n",
    "        net.load_state_dict(checkpoint['net'])\n",
    "        if cemblayer is not None:\n",
    "            cemblayer.load_state_dict(checkpoint['cemblayer'])\n",
    "    else:\n",
    "        lastepc = 0\n",
    "\n",
    "    betas = get_named_beta_schedule(num_diffusion_timesteps = params.T)\n",
    "    diffusion = GaussianDiffusion(\n",
    "                    dtype = params.dtype,\n",
    "                    model = net,\n",
    "                    betas = betas,\n",
    "                    w = params.w,\n",
    "                    v = params.v,\n",
    "                    device = device\n",
    "                )\n",
    "\n",
    "    # optimizer settings: the network and (if present) the embedding layer share one optimiser\n",
    "    model_params = list(diffusion.model.parameters())\n",
    "    if cemblayer is not None:\n",
    "        model_params += list(cemblayer.parameters())\n",
    "    optimizer = torch.optim.AdamW(\n",
    "                    model_params,\n",
    "                    lr = params.lr,\n",
    "                    weight_decay = 1e-4\n",
    "                )\n",
    "\n",
    "    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(\n",
    "                            optimizer = optimizer,\n",
    "                            T_max = params.epoch,\n",
    "                            eta_min = 0,\n",
    "                            last_epoch = -1\n",
    "                        )\n",
    "    warmUpScheduler = GradualWarmupScheduler(\n",
    "                            optimizer = optimizer,\n",
    "                            multiplier = params.multiplier,\n",
    "                            warm_epoch = params.epoch // 10,\n",
    "                            after_scheduler = cosineScheduler,\n",
    "                            last_epoch = lastepc\n",
    "                        )\n",
    "    if lastepc != 0:\n",
    "        optimizer.load_state_dict(checkpoint['optimizer'])\n",
    "        warmUpScheduler.load_state_dict(checkpoint['scheduler'])\n",
    "\n",
    "    # training\n",
    "    # device_count() is 0 on CPU-only machines; clamp so the per-device sample batch stays defined\n",
    "    cnt = max(torch.cuda.device_count(), 1)\n",
    "    for epc in range(lastepc, params.epoch):\n",
    "        # turn into train mode (sample() below leaves the modules in eval mode)\n",
    "        diffusion.model.train()\n",
    "        if cemblayer is not None:\n",
    "            cemblayer.train()\n",
    "\n",
    "        step = 0\n",
    "        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:\n",
    "            for batch in tqdmDataLoader:\n",
    "                optimizer.zero_grad()\n",
    "                x_0 = batch[params.datakey].to(device)\n",
    "\n",
    "                cemb = get_parent_embedding(params.datakey, cemblayer, batch)\n",
    "\n",
    "                loss = diffusion.trainloss(x_0, cemb = cemb)\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "                step += 1\n",
    "\n",
    "                # optional cap on the number of optimiser steps per epoch\n",
    "                if max_iters_per_epoch is not None and step >= max_iters_per_epoch:\n",
    "                    break\n",
    "\n",
    "                tqdmDataLoader.set_postfix(\n",
    "                    ordered_dict={\n",
    "                        'epoch': epc + 1,\n",
    "                        'loss: ': loss.item(),\n",
    "                        'batch per device: ': x_0.shape[0],\n",
    "                        'img shape: ': x_0.shape[1:],\n",
    "                        'LR': optimizer.state_dict()['param_groups'][0]['lr']\n",
    "                    }\n",
    "                )\n",
    "        warmUpScheduler.step()\n",
    "\n",
    "        # evaluation and save checkpoint\n",
    "        if (epc + 1) % params.interval == 0:\n",
    "            os.makedirs(moddir, exist_ok=True)\n",
    "            os.makedirs(samdir, exist_ok=True)\n",
    "\n",
    "            # draw samples conditioned on the next validation batch and save them as one grid,\n",
    "            # three images per row\n",
    "            each_device_batch = params.genbatch // cnt\n",
    "            val_batch = next(val_cycler)\n",
    "\n",
    "            final_imgs = sample(params.datakey, diffusion, cemblayer, each_device_batch, val_batch)\n",
    "\n",
    "            sample_path = os.path.join(samdir, f'generated_{epc+1}_pict.png')\n",
    "            save_image(final_imgs, sample_path, nrow = 3)\n",
    "            print('Image saved as ', sample_path)\n",
    "\n",
    "            # save checkpoints\n",
    "            checkpoint = {\n",
    "                'net': diffusion.model.state_dict(),\n",
    "                'optimizer': optimizer.state_dict(),\n",
    "                'scheduler': warmUpScheduler.state_dict()\n",
    "            }\n",
    "            if cemblayer is not None:\n",
    "                checkpoint['cemblayer'] = cemblayer.state_dict()\n",
    "\n",
    "            torch.save({'last_epoch': epc + 1}, os.path.join(moddir, 'last_epoch.pt'))\n",
    "            torch.save(checkpoint, os.path.join(moddir, f'ckpt_{epc+1}_checkpoint.pt'))\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "    return diffusion, cemblayer\n",
    "\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "from cfg.unet import Unet\n",
    "import copy\n",
    "from cfg.embedding import JointEmbedding2, JointConditionalEmbedding\n",
    "\n",
    "# endless stream of validation batches; train() pulls one whenever it renders samples\n",
    "val_cycler = cycler(val_loader)\n",
    "\n",
    "# initialize models: one Unet per variable, only W1 is built without a conditioning input\n",
    "net = {}\n",
    "for lb in ('W1', 'X', 'Y'):\n",
    "    use_cemb = (lb != 'W1')\n",
    "    net[lb] = Unet(\n",
    "        in_ch=params.inch,    # here it is 3 for W1,X,Y\n",
    "        mod_ch=params.modch,\n",
    "        out_ch=params.outch,  # here it is 3 for W1,X,Y\n",
    "        ch_mul=params.chmul,\n",
    "        num_res_blocks=params.numres,\n",
    "        cdim=params.cdim,\n",
    "        use_conv=params.useconv,\n",
    "        droprate=params.droprate,\n",
    "        dtype=params.dtype,\n",
    "        use_cemb=use_cemb,\n",
    "    )\n",
    "\n",
    "# conditioning layers, keyed by the variable they feed\n",
    "cemblayer = {\n",
    "    # W1 has no parent\n",
    "    'W1': None,\n",
    "    # X is taking W2a,W2b as input.\n",
    "    'X': JointConditionalEmbedding(num_labels_0=10,\n",
    "                                   num_labels_1=6,\n",
    "                                   d_model=params.cdim,\n",
    "                                   dim=params.cdim).to(device),\n",
    "    # Y is taking image X as input.\n",
    "    'Y': MNISTEmbedding(3, params.cdim, hw=32).to(device),\n",
    "}\n",
    "\n",
    "# train the three models one after another, each with its own checkpoint and sample directory\n",
    "for lb in ('W1', 'X', 'Y'):\n",
    "    moddir = os.path.join(params.moddir, lb)\n",
    "    samdir = os.path.join(params.samdir, lb)\n",
    "    params.datakey = lb\n",
    "    print(moddir, samdir)\n",
    "    train(net[lb], cemblayer[lb], params, moddir, samdir)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Training classifier W1 -> W2a, W1 -> W2b"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch as t\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Tensor conversion followed by normalisation with mean 0.5 and std 0.5.\n",
    "# NOTE(review): `transform` is not referenced by any cell visible in this notebook - confirm it is still needed.\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),])\n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Three-layer MLP classifier over flattened images.\n",
    "\n",
    "    Args:\n",
    "        outdim: number of output logits (one per class).\n",
    "        in_features: flattened input size; defaults to 3*32*32, the value used in this notebook.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, outdim, in_features=3*32*32):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        # attribute names are part of the checkpoint format (state_dict keys) - keep them stable\n",
    "        self.linear1 = nn.Linear(in_features, 100)\n",
    "        self.linear2 = nn.Linear(100, 50)\n",
    "        self.final = nn.Linear(50, outdim)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def forward(self, img):\n",
    "        \"\"\"Return raw logits of shape (N, outdim) for images or already-flattened vectors.\"\"\"\n",
    "        # reshape (unlike view) also accepts non-contiguous inputs such as permuted tensors\n",
    "        x = img.reshape(-1, self.in_features)\n",
    "        x = self.relu(self.linear1(x))\n",
    "        x = self.relu(self.linear2(x))\n",
    "        return self.final(x)\n",
    "# Two independent classifiers on W1: net1 is trained against W2a (10 classes) and\n",
    "# net2 against W2b (6 classes) in the training cell that follows.\n",
    "net1 = Net(outdim=10).to(device)\n",
    "net2 = Net(outdim=6).to(device)\n",
    "\n",
    "\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Jointly train net1 (W1 -> W2a) and net2 (W1 -> W2b) with one optimiser on the summed loss.\n",
    "cross_el1 = nn.CrossEntropyLoss()\n",
    "cross_el2 = nn.CrossEntropyLoss()\n",
    "optimizer = t.optim.Adam(list(net1.parameters()) + list(net2.parameters()), lr=0.001)\n",
    "# kept separate from the loop variable so the epoch count is not overwritten while iterating\n",
    "num_epochs = 20\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    net1.train()\n",
    "    net2.train()\n",
    "\n",
    "    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:\n",
    "        for batch in tqdmDataLoader:\n",
    "            # named `x` rather than `input` so the builtin is not shadowed\n",
    "            x = batch['W1'].to(device)\n",
    "            # one-hot float targets are consumed by CrossEntropyLoss as class probabilities\n",
    "            lab1 = F.one_hot(batch['W2a'].to(device), num_classes=10).float()\n",
    "            lab2 = F.one_hot(batch['W2b'].to(device), num_classes=6).float()\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            output1 = net1(x.view(-1, 3*32*32))\n",
    "            output2 = net2(x.view(-1, 3*32*32))\n",
    "            loss = cross_el1(output1, lab1) + cross_el2(output2, lab2)\n",
    "\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "# persist both classifiers in one checkpoint, using the same file layout as the diffusion models\n",
    "checkpoint = {\n",
    "    'netW2a': net1.state_dict(),\n",
    "    'netW2b': net2.state_dict(),\n",
    "    'optimizer': optimizer.state_dict(),\n",
    "}\n",
    "\n",
    "moddir = os.path.join(params.moddir, 'W2')\n",
    "os.makedirs(moddir, exist_ok=True)\n",
    "torch.save({'last_epoch': num_epochs}, os.path.join(moddir, 'last_epoch.pt'))\n",
    "torch.save(checkpoint, os.path.join(moddir, f'ckpt_{num_epochs}_checkpoint.pt'))\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# NOTE: this rebinds the global `val_loader` (batch size 100) that later cells also read\n",
    "val_loader = load_data(val_data, 100)\n",
    "\n",
    "net1.eval()\n",
    "net2.eval()\n",
    "\n",
    "total = 0\n",
    "correct1 = 0\n",
    "correct2 = 0\n",
    "with t.no_grad():\n",
    "    with tqdm(val_loader, dynamic_ncols=True) as tqdmDataLoader:\n",
    "        for batch in tqdmDataLoader:\n",
    "            # named `x` rather than `input` so the builtin is not shadowed\n",
    "            x = batch['W1'].to(device)\n",
    "            lab1 = F.one_hot(batch['W2a'].to(device), num_classes=10).float()\n",
    "            lab2 = F.one_hot(batch['W2b'].to(device), num_classes=6).float()\n",
    "\n",
    "            output1 = net1(x.view(-1, 3*32*32))\n",
    "            output2 = net2(x.view(-1, 3*32*32))\n",
    "\n",
    "            # compare predicted and true class for the whole batch at once\n",
    "            correct1 += (output1.argmax(dim=1) == lab1.argmax(dim=1)).sum().item()\n",
    "            correct2 += (output2.argmax(dim=1) == lab2.argmax(dim=1)).sum().item()\n",
    "            total += output1.shape[0]\n",
    "\n",
    "# drop_last=True yields no batches when the dataset is smaller than one batch\n",
    "assert total > 0, 'validation loader produced no batches'\n",
    "print(f'accuracy: {round(correct1/total, 3)},  {round(correct2/total, 3)}')\n",
    "\n",
    "# show one image from the last batch next to the two predictions and its one-hot labels\n",
    "n=18\n",
    "plt.imshow(x[n].permute(1,2,0).cpu())\n",
    "plt.show()\n",
    "print(t.argmax(net1(x[n].view(-1, 3*32*32))[0]), lab1[n])\n",
    "print(t.argmax(net2(x[n].view(-1, 3*32*32))[0]), lab2[n])\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Model Load"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'cycler' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mNameError\u001B[0m                                 Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[3], line 34\u001B[0m\n\u001B[1;32m     31\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mcfg\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01munet\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Unet\n\u001B[1;32m     32\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mcfg\u001B[39;00m\u001B[38;5;21;01m.\u001B[39;00m\u001B[38;5;21;01membedding\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m JointEmbedding2, JointConditionalEmbedding\n\u001B[0;32m---> 34\u001B[0m val_cycler \u001B[38;5;241m=\u001B[39m \u001B[43mcycler\u001B[49m(val_loader)\n\u001B[1;32m     36\u001B[0m \u001B[38;5;66;03m# initialize models\u001B[39;00m\n\u001B[1;32m     37\u001B[0m net\u001B[38;5;241m=\u001B[39m{}\n",
      "\u001B[0;31mNameError\u001B[0m: name 'cycler' is not defined"
     ]
    }
   ],
   "source": [
    "\n",
    "def mode_load(net, cemblayer, params, moddir):\n",
    "    \"\"\"Restore the most recent checkpoint in ``moddir`` and wrap ``net`` in a GaussianDiffusion.\n",
    "\n",
    "    Args:\n",
    "        net: network whose weights are loaded from the checkpoint's 'net' entry.\n",
    "        cemblayer: embedding module restored from the 'cemblayer' entry, or None.\n",
    "        params: hyperparameter namespace (``T``, ``dtype``, ``w``, ``v`` are read).\n",
    "        moddir: directory containing ``last_epoch.pt`` and ``ckpt_<epoch>_checkpoint.pt``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(diffusion, cemblayer)``.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if ``moddir`` has no ``last_epoch.pt``.\n",
    "    \"\"\"\n",
    "    lastpath = os.path.join(moddir, 'last_epoch.pt')\n",
    "    if not os.path.exists(lastpath):\n",
    "        # This used to print 'No saved models' and fall through to an implicit None, which made\n",
    "        # callers that unpack two values fail with an unrelated TypeError.\n",
    "        raise FileNotFoundError(f'No saved models: {lastpath} does not exist')\n",
    "\n",
    "    lastepc = torch.load(lastpath)['last_epoch']\n",
    "    # load checkpoints\n",
    "    checkpoint = torch.load(os.path.join(moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')\n",
    "    net.load_state_dict(checkpoint['net'])\n",
    "    if cemblayer is not None:\n",
    "        cemblayer.load_state_dict(checkpoint['cemblayer'])\n",
    "\n",
    "    betas = get_named_beta_schedule(num_diffusion_timesteps = params.T)\n",
    "    diffusion = GaussianDiffusion(\n",
    "                    dtype = params.dtype,\n",
    "                    model = net,\n",
    "                    betas = betas,\n",
    "                    w = params.w,\n",
    "                    v = params.v,\n",
    "                    device = device\n",
    "                )\n",
    "\n",
    "    return diffusion, cemblayer\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "from cfg.unet import Unet\n",
    "from cfg.embedding import JointEmbedding2, JointConditionalEmbedding\n",
    "\n",
    "# This cell needs `cycler`, `val_loader` and `Net` from the cells above; on a fresh kernel it\n",
    "# stops with a NameError until those cells have been run.\n",
    "val_cycler = cycler(val_loader)\n",
    "\n",
    "# initialize models: one Unet per variable, only W1 is built without a conditioning input\n",
    "net={}\n",
    "for lb in ['W1', 'X', 'Y']:\n",
    "    use_cemb = (lb != 'W1')\n",
    "\n",
    "    net[lb] = Unet(\n",
    "                in_ch = params.inch,  # here it is 3 for W1,X,Y\n",
    "                mod_ch = params.modch,\n",
    "                out_ch = params.outch, # here it is 3 for W1,X,Y\n",
    "                ch_mul = params.chmul,\n",
    "                num_res_blocks = params.numres,\n",
    "                cdim = params.cdim,\n",
    "                use_conv = params.useconv,\n",
    "                droprate = params.droprate,\n",
    "                dtype = params.dtype,\n",
    "                use_cemb= use_cemb\n",
    "            )\n",
    "\n",
    "\n",
    "cemblayer={}\n",
    "\n",
    "cemblayer['W1']= None # No parent\n",
    "\n",
    "# X is taking W2a,W2b as input.\n",
    "cemblayer['X'] = JointConditionalEmbedding(num_labels_0=10, num_labels_1=6,\n",
    "                           d_model=params.cdim,\n",
    "                           dim=params.cdim).to(device)\n",
    "\n",
    "# Y is taking image X as input.\n",
    "cemblayer['Y'] = MNISTEmbedding(3, params.cdim, hw=32).to(device)\n",
    "\n",
    "\n",
    "# restore every diffusion model (and its embedding layer) and switch them to eval mode\n",
    "diffusions, cemblayers= {}, {}\n",
    "for lb in ['W1','X', 'Y']:\n",
    "    moddir= os.path.join(params.moddir, lb)\n",
    "    samdir= os.path.join(params.samdir, lb)\n",
    "    params.datakey=lb\n",
    "    print(moddir, samdir)\n",
    "    diffusions[lb], cemblayers[lb]= mode_load(net[lb], cemblayer[lb],  params, moddir)\n",
    "    diffusions[lb].eval()\n",
    "\n",
    "    if cemblayers[lb] is not None:\n",
    "        cemblayers[lb].eval()\n",
    "\n",
    "\n",
    "# restore the two W1 -> W2 classifiers\n",
    "net1 = Net(outdim=10).to(device)\n",
    "net2 = Net(outdim=6).to(device)\n",
    "\n",
    "moddir= os.path.join(params.moddir, 'W2')\n",
    "lastpath = os.path.join(moddir,f'last_epoch.pt')\n",
    "if os.path.exists(lastpath):\n",
    "    lastepc = torch.load(lastpath)['last_epoch']\n",
    "\n",
    "    checkpoint = torch.load(os.path.join(moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')\n",
    "\n",
    "    net1.load_state_dict(checkpoint['netW2a'])\n",
    "    net2.load_state_dict(checkpoint['netW2b'])\n",
    "else:\n",
    "    # make the fallback visible instead of silently continuing with untrained classifiers\n",
    "    print(f'No saved W2 classifiers in {moddir}; net1/net2 keep their random initialisation')\n",
    "\n",
    "net1.eval()\n",
    "net2.eval()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Inference"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/600 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating W1\n",
      "Start generating...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|██▎       | 229/1000 [00:06<00:22, 33.68it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[56], line 8\u001B[0m\n\u001B[1;32m      5\u001B[0m N\u001B[38;5;241m=\u001B[39mval_batch[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mW1\u001B[39m\u001B[38;5;124m'\u001B[39m]\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m]\n\u001B[1;32m      7\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mGenerating W1\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[0;32m----> 8\u001B[0m W1\u001B[38;5;241m=\u001B[39m \u001B[43msample\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mW1\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdiffusions\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mW1\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43;01mNone\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mN\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mval_batch\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     10\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mGenerating W2\u001B[39m\u001B[38;5;124m'\u001B[39m)\n\u001B[1;32m     11\u001B[0m W2a \u001B[38;5;241m=\u001B[39m net1(W1\u001B[38;5;241m.\u001B[39mview(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m3\u001B[39m\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m32\u001B[39m\u001B[38;5;241m*\u001B[39m\u001B[38;5;241m32\u001B[39m))\n",
      "Cell \u001B[0;32mIn[27], line 42\u001B[0m, in \u001B[0;36msample\u001B[0;34m(datakey, diffusion, cemblayer, batch_size, parent_batch)\u001B[0m\n\u001B[1;32m     40\u001B[0m     generated \u001B[38;5;241m=\u001B[39m diffusion\u001B[38;5;241m.\u001B[39mddim_sample(genshape, params\u001B[38;5;241m.\u001B[39mnum_steps, params\u001B[38;5;241m.\u001B[39meta, params\u001B[38;5;241m.\u001B[39mselect, cemb \u001B[38;5;241m=\u001B[39m cemb)\n\u001B[1;32m     41\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m---> 42\u001B[0m     generated \u001B[38;5;241m=\u001B[39m \u001B[43mdiffusion\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msample\u001B[49m\u001B[43m(\u001B[49m\u001B[43mgenshape\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcemb\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mcemb\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     44\u001B[0m \u001B[38;5;66;03m# cond = transback(cond)\u001B[39;00m\n\u001B[1;32m     45\u001B[0m img \u001B[38;5;241m=\u001B[39m transback(generated)\n",
      "File \u001B[0;32m/root/PycharmProjects/IDGEN/cfg/diffusion.py:177\u001B[0m, in \u001B[0;36mGaussianDiffusion.sample\u001B[0;34m(self, shape, disable_tqdm, **model_kwargs)\u001B[0m\n\u001B[1;32m    175\u001B[0m     tlist \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[1;32m    176\u001B[0m     \u001B[38;5;28;01mwith\u001B[39;00m torch\u001B[38;5;241m.\u001B[39mno_grad():\n\u001B[0;32m--> 177\u001B[0m         x_t \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mp_sample\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx_t\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtlist\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mmodel_kwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    178\u001B[0m x_t \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mclamp(x_t, \u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m    179\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m local_rank \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m disable_tqdm:\n",
      "File \u001B[0;32m/root/PycharmProjects/IDGEN/cfg/diffusion.py:153\u001B[0m, in \u001B[0;36mGaussianDiffusion.p_sample\u001B[0;34m(self, x_t, t, **model_kwargs)\u001B[0m\n\u001B[1;32m    151\u001B[0m B, C \u001B[38;5;241m=\u001B[39m x_t\u001B[38;5;241m.\u001B[39mshape[:\u001B[38;5;241m2\u001B[39m]\n\u001B[1;32m    152\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m t\u001B[38;5;241m.\u001B[39mshape \u001B[38;5;241m==\u001B[39m (B,), \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124msize of t is not batch size \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mB\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m--> 153\u001B[0m mean, var \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mp_mean_variance\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx_t\u001B[49m\u001B[43m \u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mt\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mmodel_kwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m    154\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m torch\u001B[38;5;241m.\u001B[39misnan(mean)\u001B[38;5;241m.\u001B[39mint()\u001B[38;5;241m.\u001B[39msum() \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m, \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mnan in tensor mean when t = \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mt[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m    155\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m torch\u001B[38;5;241m.\u001B[39misnan(var)\u001B[38;5;241m.\u001B[39mint()\u001B[38;5;241m.\u001B[39msum() \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m, \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mnan in tensor var when t = 
\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mt[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n",
      "File \u001B[0;32m/root/PycharmProjects/IDGEN/cfg/diffusion.py:130\u001B[0m, in \u001B[0;36mGaussianDiffusion.p_mean_variance\u001B[0;34m(self, x_t, t, **model_kwargs)\u001B[0m\n\u001B[1;32m    127\u001B[0m     pred_eps_uncond \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel(x_t, t, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mmodel_kwargs)\n\u001B[1;32m    128\u001B[0m     pred_eps \u001B[38;5;241m=\u001B[39m (\u001B[38;5;241m1\u001B[39m \u001B[38;5;241m+\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mw) \u001B[38;5;241m*\u001B[39m pred_eps_cond \u001B[38;5;241m-\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mw \u001B[38;5;241m*\u001B[39m pred_eps_uncond\n\u001B[0;32m--> 130\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m torch\u001B[38;5;241m.\u001B[39misnan(x_t)\u001B[38;5;241m.\u001B[39mint()\u001B[38;5;241m.\u001B[39msum() \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m, \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mnan in tensor x_t when t = \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mt[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m    131\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m torch\u001B[38;5;241m.\u001B[39misnan(t)\u001B[38;5;241m.\u001B[39mint()\u001B[38;5;241m.\u001B[39msum() \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m, \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mnan in tensor t when t = \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mt[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m    132\u001B[0m \u001B[38;5;28;01massert\u001B[39;00m torch\u001B[38;5;241m.\u001B[39misnan(pred_eps)\u001B[38;5;241m.\u001B[39mint()\u001B[38;5;241m.\u001B[39msum() \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m0\u001B[39m, \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mnan 
in tensor pred_eps when t = \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mt[\u001B[38;5;241m0\u001B[39m]\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "# Fetch a single batch from the validation loader.\n",
    "# next() raises StopIteration right here when the loader yields nothing,\n",
    "# instead of leaving `val_batch` unbound (or stale from an earlier run of\n",
    "# this cell) as a `for ...: break` loop would.\n",
    "with tqdm(val_loader, dynamic_ncols=True) as tqdmDataLoader:\n",
    "    val_batch = next(iter(tqdmDataLoader))\n",
    "\n",
    "# N is the size of the first dimension of the batch's 'W1' entry.\n",
    "N = val_batch['W1'].shape[0]\n",
    "\n",
    "print('Generating W1')\n",
    "# The third argument (a cemblayers[...] entry in the X and Y calls) is None\n",
    "# here; the whole validation batch is passed as the last argument.\n",
    "W1 = sample(\"W1\", diffusions['W1'], None, N, val_batch)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating W2\n"
     ]
    },
    {
     "data": {
      "text/plain": "tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0,\n        1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1,\n        1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0,\n        1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0,\n        1, 1, 0, 1], device='cuda:0')"
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print('Generating W2')\n",
    "# Inference only: no_grad() stops autograd from recording a graph for the two\n",
    "# forward passes. The argmax results carry no gradient either way, so the\n",
    "# values produced are unchanged.\n",
    "with torch.no_grad():\n",
    "    # Flatten W1 into rows of length 3*32*32 once and feed both networks.\n",
    "    W1_flat = W1.view(-1, 3*32*32)\n",
    "    # Index of the largest output along dim 1 for each row.\n",
    "    W2a = t.argmax(net1(W1_flat), dim=1)\n",
    "    W2b = t.argmax(net2(W1_flat), dim=1)\n",
    "\n",
    "# One-hot encoding of the indices is currently disabled:\n",
    "# W2a= F.one_hot( W2a, num_classes=10)\n",
    "# W2b= F.one_hot( W2b, num_classes=6)\n",
    "W2b"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating X\n",
      "Start generating...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:59<00:00, 16.88it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ending sampling process...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Call `sample` for \"X\" with its diffusion object and embedding layer,\n",
    "# passing the two W2 index tensors (keys 'W2a' and 'W2b') as the last argument.\n",
    "print('Generating X')\n",
    "X = sample(\n",
    "    \"X\",\n",
    "    diffusions['X'],\n",
    "    cemblayers['X'],\n",
    "    N,\n",
    "    dict(W2a=W2a, W2b=W2b),\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating Y\n",
      "Start generating...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:59<00:00, 16.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ending sampling process...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Call `sample` for \"Y\" with its diffusion object and embedding layer,\n",
    "# passing the previously generated X (key 'X') as the last argument.\n",
    "print('Generating Y')\n",
    "Y = sample(\n",
    "    \"Y\",\n",
    "    diffusions['Y'],\n",
    "    cemblayers['Y'],\n",
    "    N,\n",
    "    dict(X=X),\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
