{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "\"\"\" Builds a (pytorch) dataset with data constructed according to the structural equation model\n",
    "\n",
    "U1  := Unif({0,1,2,3,4,5}) where this is 0->Red, 1->Green, 2->Blue, 3->Yellow, 4->Magenta, 5->Cyan\n",
    "U2  := Unif({0,1,2}) where 0->Thin, 1->Regular, 2->Thick\n",
    "D   := The digit [0..9]\n",
    "W1  := an MNIST image with (digit:D, Color:U1, Thickness:U2)\n",
    "W2a := a DIGIT, W1.digit\n",
    "W2b := a MODIFIED COLOR CODE, taken from U1 // 3\n",
    "X   := An MNIST image, rotated 90 degrees with (digit:W2a, color: W2b, thickness:U2)\n",
    "Y   := An MNIST image, reflected, with (digit: X.digit, color: U1, thickness:X.thickness)\n",
    "\n",
    "\n",
    "So the way we will form this is:\n",
    "1.  Sample U1, U2, D\n",
    "2.  Sample W1: pick random mnist image from D, apply color:U1, thickness:U2\n",
    "3a. Sample W2a with support 10: take digit D and apply massart noise to it\n",
    "3b. Sample W2b with support 2: take color U1, and transform to U1 // 3 (and apply massart noise)\n",
    "4.  Sample X: pick random mnist image with Digit W2a, apply color W2b, thickness U2 (randomly massart noise this, too). Rotate 90 degrees\n",
    "5.  Sample Y: take X and recolor with U1 (the color variable, as in the definition of Y above). Reflect. Apply random Massart, too\n",
    "\n",
    "\n",
    "(Note: we'll do all the data preparation in numpy because we gotta use morphoMNIST to do these things)\n",
    "\"\"\""
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import argparse\n",
    "import itertools\n",
    "import numpy as np\n",
    "from unet import Unet\n",
    "from tqdm import tqdm\n",
    "import torch.optim as optim\n",
    "from diffusion import GaussianDiffusion\n",
    "from torchvision.utils import save_image\n",
    "from utils import get_named_beta_schedule\n",
    "from embedding import ConditionalEmbedding, MNISTEmbedding, JointEmbedding2\n",
    "from Scheduler import GradualWarmupScheduler\n",
    "\n",
    "import sys; sys.path.append('../retrain_trick'); sys.path.append('../Morpho-MNIST')\n",
    "#from dataloader_cifar import load_data, transback\n",
    "from dataloader_pickle import transback, PickleDataset\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP\n",
    "from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torch.utils.data.distributed import DistributedSampler"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Training split, loaded directly so its contents can be inspected in the next cell.\n",
    "file = '../napkin_mnist/base_data/napkin_mnist_train.pkl'\n",
    "\n",
    "# NOTE(review): unpickling can execute arbitrary code -- only load files you produced yourself.\n",
    "with open(file, mode='rb') as f:\n",
    "    real_data = pickle.load(f)\n",
    "\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor(1)"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on the coarse colour code: a maximum of 1 is consistent with two classes (0 and 1).\n",
    "max(real_data['W2b'])"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "parser = argparse.ArgumentParser(description='test for diffusion model')\n",
    "parser.add_argument('--train_pkl', type=str, default=\"../napkin_mnist/base_data/napkin_mnist_train.pkl\")\n",
    "parser.add_argument('--val_pkl', type=str, default=\"../napkin_mnist/base_data/napkin_mnist_val.pkl\")\n",
    "parser.add_argument('--datakey',  type=str, default=\"X,Y\", help='comma separated list of keys to be concatenated in channel dimension to make a single joint model')\n",
    "parser.add_argument('--labkey_0', type=str, default=\"W2a\", help='which key contains the discrete conditioning')\n",
    "parser.add_argument('--labkey_1', type=str, default=\"W2b\", help='which key contains the discrete conditioning')\n",
    "parser.add_argument('--condkey', type=str, default=\"W1\", help='which key contains the image conditioning')\n",
    "parser.add_argument('--batchsize',type=int,default=256,help='batch size per device for training Unet model')\n",
    "parser.add_argument('--numworkers',type=int,default=4,help='num workers for training Unet model')\n",
    "parser.add_argument('--inch',type=int,default=6,help='input channels for Unet model')\n",
    "parser.add_argument('--modch',type=int,default=64,help='model channels for Unet model')\n",
    "parser.add_argument('--T',type=int,default=1000,help='timesteps for Unet model')\n",
    "parser.add_argument('--outch',type=int,default=6,help='output channels for Unet model')\n",
    "parser.add_argument('--chmul',type=list,default=[1,2,2,2],help='architecture parameters training Unet model')\n",
    "parser.add_argument('--numres',type=int,default=2,help='number of resblocks for each block in Unet model')\n",
    "parser.add_argument('--cdim',type=int,default=64,help='dimension of conditional embedding')\n",
    "parser.add_argument('--useconv',type=bool,default=True,help='whether use convlution in downsample')\n",
    "parser.add_argument('--droprate',type=float,default=0.1,help='dropout rate for model')\n",
    "parser.add_argument('--dtype',default=torch.float32)\n",
    "parser.add_argument('--lr',type=float,default=2e-4,help='learning rate')\n",
    "parser.add_argument('--w',type=float,default=1.8,help='hyperparameters for classifier-free guidance strength')\n",
    "parser.add_argument('--v',type=float,default=0.3,help='hyperparameters for the variance of posterior distribution')\n",
    "parser.add_argument('--epoch',type=int,default=301,help='epochs for training')\n",
    "parser.add_argument('--multiplier',type=float,default=2.5,help='multiplier for warmup')\n",
    "parser.add_argument('--threshold',type=float,default=0.1,help='threshold for classifier-free guidance')\n",
    "parser.add_argument('--interval',type=int,default=1,help='epoch interval between two evaluations')\n",
    "parser.add_argument('--moddir',type=str,default='../napkin_mnist/synthetic_model',help='model addresses')\n",
    "parser.add_argument('--samdir',type=str,default='../napkin_mnist/synthetic_model_evals',help='sample addresses')\n",
    "parser.add_argument('--genbatch',type=int,default=80,help='batch size for sampling process')\n",
    "parser.add_argument('--clsnum_0',type=int,default=10,help='num of label classes')\n",
    "parser.add_argument('--clsnum_1',type=int,default=2,help='num of label classes')\n",
    "parser.add_argument('--num_steps',type=int,default=50,help='sampling steps for DDIM')\n",
    "parser.add_argument('--eta',type=float,default=0,help='eta for variance during DDIM sampling process')\n",
    "parser.add_argument('--select',type=str,default='linear',help='selection stragies for DDIM')\n",
    "parser.add_argument('--ddim',type=lambda x:(str(x).lower() in ['true','1', 'yes']),default=False,help='whether to use ddim')\n",
    "parser.add_argument('--local_rank',default=-1,type=int,help='node rank for distributed training')\n",
    "\n",
    "# args = parser.parse_args()\n",
    "params, unknown = parser.parse_known_args()\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "data": {
      "text/plain": "'../napkin_mnist/synthetic_model_evals/Run2'"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exp_name = \"Run2\"\n",
    "\n",
    "\n",
    "def _run_dir(base_dir, run_name):\n",
    "    \"\"\"Return `base_dir` with `run_name` as its final path component, appending it only if missing.\"\"\"\n",
    "    # Compare the last path component rather than doing a substring search, so that e.g.\n",
    "    # 'Run2' is not mistaken for already being present in '.../Run20'.\n",
    "    if os.path.basename(os.path.normpath(base_dir)) == run_name:\n",
    "        return base_dir\n",
    "    return base_dir + '/' + run_name\n",
    "\n",
    "\n",
    "# Remember the un-suffixed directories the first time this cell runs, so re-running it\n",
    "# (possibly with a different exp_name) always derives the paths from the same base.\n",
    "if not hasattr(params, 'base_moddir'):\n",
    "    params.base_moddir = params.moddir\n",
    "    params.base_samdir = params.samdir\n",
    "\n",
    "params.moddir = _run_dir(params.base_moddir, exp_name)\n",
    "params.samdir = _run_dir(params.base_samdir, exp_name)\n",
    "params.samdir"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [],
   "source": [
    "def cycler(loader):\n",
    "    while True:\n",
    "        for batch in loader:\n",
    "            yield batch\n",
    "\n",
    "\n",
    "def load_data(dataset: PickleDataset, batchsize: int)-> tuple[DataLoader, DistributedSampler]:\n",
    "\t\ttrainloader = DataLoader(dataset,\n",
    "\t\t\t\t\t\t\t     batch_size=batchsize,\n",
    "                                 shuffle=True,\n",
    "\t\t\t\t\t\t\t     drop_last=True)\n",
    "\t\treturn trainloader\n",
    "\n",
    "# set device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# load data\n",
    "train_data = PickleDataset(params.train_pkl)\n",
    "val_data = PickleDataset(params.val_pkl)\n",
    "dataloader = load_data(train_data, params.batchsize)\n",
    "# torch.cuda.device_count() is 0 on CPU-only hosts; clamp it to 1 so the CPU fallback\n",
    "# chosen above does not fail with ZeroDivisionError.\n",
    "val_loader = load_data(val_data, params.genbatch // max(torch.cuda.device_count(), 1))\n",
    "val_cycler = cycler(val_loader)\n",
    "\n",
    "# Keys whose tensors are stacked into the joint model input; tolerate spaces after the commas.\n",
    "data_keys = [key.strip() for key in params.datakey.split(',')]\n",
    "\n",
    "\n",
    "def datacat(batch):\n",
    "    \"\"\"Concatenate the `data_keys` entries of `batch` along dim 1 into one contiguous tensor.\"\"\"\n",
    "    return torch.cat([batch[k] for k in data_keys], dim=1).contiguous()\n",
    "\n",
    "\n",
    "# Per-sample shape of the joint tensor; the training loop sets it from each batch it sees.\n",
    "datashape = None\n",
    "\n",
    "# initialize models\n",
    "net = Unet(\n",
    "    in_ch=params.inch,\n",
    "    mod_ch=params.modch,\n",
    "    out_ch=params.outch,\n",
    "    ch_mul=params.chmul,\n",
    "    num_res_blocks=params.numres,\n",
    "    cdim=params.cdim,\n",
    "    use_conv=params.useconv,\n",
    "    droprate=params.droprate,\n",
    "    dtype=params.dtype,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 234/234 [01:04<00:00,  3.62it/s, epoch=1, loss: =0.0592, batch per device: =256, img shape: =torch.Size([6, 32, 32]), LR=0.0002]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Start generating...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1000/1000 [00:34<00:00, 28.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ending sampling process...\n",
      "Image saved as  ../napkin_mnist/synthetic_model_evals/Run2/generated_1_pict.png\n",
      "Model saved as  ../napkin_mnist/synthetic_model/Run2/ckpt_1_checkpoint.pt\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 49%|████▊     | 114/234 [00:31<00:32,  3.66it/s, epoch=2, loss: =0.0468, batch per device: =256, img shape: =torch.Size([6, 32, 32]), LR=0.00021]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "Cell \u001B[0;32mIn[7], line 74\u001B[0m\n\u001B[1;32m     71\u001B[0m cemb \u001B[38;5;241m=\u001B[39m F\u001B[38;5;241m.\u001B[39mdropout1d(cemb, params\u001B[38;5;241m.\u001B[39mthreshold)\n\u001B[1;32m     73\u001B[0m loss \u001B[38;5;241m=\u001B[39m diffusion\u001B[38;5;241m.\u001B[39mtrainloss(x_0, cemb \u001B[38;5;241m=\u001B[39m cemb)\n\u001B[0;32m---> 74\u001B[0m \u001B[43mloss\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m     75\u001B[0m optimizer\u001B[38;5;241m.\u001B[39mstep()\n\u001B[1;32m     76\u001B[0m tqdmDataLoader\u001B[38;5;241m.\u001B[39mset_postfix(\n\u001B[1;32m     77\u001B[0m     ordered_dict\u001B[38;5;241m=\u001B[39m{\n\u001B[1;32m     78\u001B[0m         \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mepoch\u001B[39m\u001B[38;5;124m\"\u001B[39m: epc \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m,\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m     83\u001B[0m     }\n\u001B[1;32m     84\u001B[0m )\n",
      "File \u001B[0;32m/root/PycharmProjects/IDGEN/venv/lib/python3.10/site-packages/torch/_tensor.py:492\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m    482\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m    483\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m    484\u001B[0m         Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m    485\u001B[0m         (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m   (...)\u001B[0m\n\u001B[1;32m    490\u001B[0m         inputs\u001B[38;5;241m=\u001B[39minputs,\n\u001B[1;32m    491\u001B[0m     )\n\u001B[0;32m--> 492\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\n\u001B[1;32m    493\u001B[0m \u001B[43m    \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\n\u001B[1;32m    494\u001B[0m \u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "File \u001B[0;32m/root/PycharmProjects/IDGEN/venv/lib/python3.10/site-packages/torch/autograd/__init__.py:251\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m    246\u001B[0m     retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m    248\u001B[0m \u001B[38;5;66;03m# The reason we repeat the same comment below is that\u001B[39;00m\n\u001B[1;32m    249\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m    250\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 251\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m  \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m    252\u001B[0m \u001B[43m    \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    253\u001B[0m \u001B[43m    \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    254\u001B[0m \u001B[43m    \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    255\u001B[0m \u001B[43m    \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    256\u001B[0m \u001B[43m    \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m    257\u001B[0m \u001B[43m    \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\n\u001B[1;32m    258\u001B[0m \u001B[43m    \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\n\u001B[1;32m    259\u001B[0m 
\u001B[43m\u001B[49m\u001B[43m)\u001B[49m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    }
   ],
   "source": [
    "# Embedding network that turns the conditioning image plus the two discrete labels\n",
    "# into the conditioning vector passed to the diffusion model.\n",
    "cemblayer = JointEmbedding2(  #user\n",
    "    num_labels_0=params.clsnum_0,\n",
    "    num_labels_1=params.clsnum_1,\n",
    "    d_model=params.cdim,\n",
    "    channels=3,\n",
    "    dim=params.cdim,\n",
    "    hw=32,\n",
    ").to(device)\n",
    "\n",
    "# Resume from the most recent checkpoint if a 'last_epoch.pt' marker exists in moddir;\n",
    "# otherwise start from epoch 0.\n",
    "lastpath = os.path.join(params.moddir, 'last_epoch.pt')\n",
    "if not os.path.exists(lastpath):\n",
    "    lastepc = 0\n",
    "else:\n",
    "    lastepc = torch.load(lastpath)['last_epoch']\n",
    "    # load checkpoints\n",
    "    checkpoint = torch.load(\n",
    "        os.path.join(params.moddir, f'ckpt_{lastepc}_checkpoint.pt'),\n",
    "        map_location='cpu',\n",
    "    )\n",
    "    net.load_state_dict(checkpoint['net'])\n",
    "    cemblayer.load_state_dict(checkpoint['cemblayer'])\n",
    "\n",
    "# Noise schedule and diffusion wrapper around the U-Net.\n",
    "betas = get_named_beta_schedule(num_diffusion_timesteps=params.T)\n",
    "diffusion = GaussianDiffusion(\n",
    "    dtype=params.dtype,\n",
    "    model=net,\n",
    "    betas=betas,\n",
    "    w=params.w,\n",
    "    v=params.v,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# optimizer settings: one AdamW over both the diffusion model and the embedding layer\n",
    "optimizer = torch.optim.AdamW(\n",
    "    itertools.chain(diffusion.model.parameters(), cemblayer.parameters()),\n",
    "    lr=params.lr,\n",
    "    weight_decay=1e-4,\n",
    ")\n",
    "\n",
    "# Cosine annealing is handed to the warm-up scheduler as its follow-up schedule;\n",
    "# the warm-up phase is configured for the first tenth of the epochs.\n",
    "cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(\n",
    "    optimizer=optimizer,\n",
    "    T_max=params.epoch,\n",
    "    eta_min=0,\n",
    "    last_epoch=-1,\n",
    ")\n",
    "warmUpScheduler = GradualWarmupScheduler(\n",
    "    optimizer=optimizer,\n",
    "    multiplier=params.multiplier,\n",
    "    warm_epoch=params.epoch // 10,\n",
    "    after_scheduler=cosineScheduler,\n",
    "    last_epoch=lastepc,\n",
    ")\n",
    "\n",
    "# When resuming, restore optimizer and scheduler state from the checkpoint loaded earlier.\n",
    "if lastepc != 0:\n",
    "    optimizer.load_state_dict(checkpoint['optimizer'])\n",
    "    warmUpScheduler.load_state_dict(checkpoint['scheduler'])\n",
    "\n",
    "\n",
    "# training\n",
    "# torch.cuda.device_count() is 0 on CPU-only hosts; clamp it to 1 so the per-device\n",
    "# split of genbatch below cannot divide by zero.\n",
    "cnt = max(torch.cuda.device_count(), 1)\n",
    "for epc in range(lastepc, params.epoch):\n",
    "    # turn into train mode\n",
    "    diffusion.model.train()\n",
    "    cemblayer.train()\n",
    "    # batch iterations\n",
    "    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:\n",
    "        for batch in tqdmDataLoader:\n",
    "            optimizer.zero_grad()\n",
    "            x_0 = datacat(batch).to(device)\n",
    "            # Remember the per-sample shape; the sampling step below needs it.\n",
    "            datashape = tuple(x_0.shape[1:])\n",
    "            cemb = cemblayer(batch[params.condkey].to(device),  #user\n",
    "                             batch[params.labkey_0].to(device),\n",
    "                             batch[params.labkey_1].to(device))\n",
    "            # Randomly drop conditioning with probability `threshold` (classifier-free guidance).\n",
    "            # NOTE(review): presumably cemb is (batch, cdim), so dropout1d zeroes whole rows -- confirm.\n",
    "            cemb = F.dropout1d(cemb, params.threshold)\n",
    "\n",
    "            loss = diffusion.trainloss(x_0, cemb = cemb)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            tqdmDataLoader.set_postfix(\n",
    "                ordered_dict={\n",
    "                    \"epoch\": epc + 1,\n",
    "                    \"loss: \": loss.item(),\n",
    "                    \"batch per device: \":x_0.shape[0],\n",
    "                    \"img shape: \": x_0.shape[1:],\n",
    "                    # Read the LR straight from the optimizer instead of building a full state_dict every step.\n",
    "                    \"LR\": optimizer.param_groups[0][\"lr\"]\n",
    "                }\n",
    "            )\n",
    "\n",
    "\n",
    "    warmUpScheduler.step()\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    if (epc + 1) % params.interval == 0:\n",
    "        os.makedirs(params.moddir, exist_ok=True)\n",
    "        os.makedirs(params.samdir, exist_ok=True)\n",
    "\n",
    "        diffusion.model.eval()\n",
    "        cemblayer.eval()\n",
    "        # generating samples\n",
    "        # Saves a grid with one row per validation sample and 3 columns:\n",
    "        # column 0: conditioning image\n",
    "        # column 1: generated image (1)\n",
    "        # column 2: generated image (2)\n",
    "        each_device_batch = params.genbatch // cnt\n",
    "        val_batch = next(val_cycler)\n",
    "        with torch.no_grad():\n",
    "            cond = val_batch[params.condkey].to(device)\n",
    "            cemb = cemblayer(cond,                          #user\n",
    "                             val_batch[params.labkey_0].to(device),\n",
    "                             val_batch[params.labkey_1].to(device))\n",
    "            genshape = (each_device_batch ,) + datashape\n",
    "            if params.ddim:\n",
    "                generated = diffusion.ddim_sample(genshape, params.num_steps, params.eta, params.select, cemb = cemb)\n",
    "            else:\n",
    "                generated = diffusion.sample(genshape, cemb = cemb)\n",
    "\n",
    "            cond = transback(cond)\n",
    "            img = transback(generated)\n",
    "\n",
    "            # Stack conditioning and generated channels, then split them back into\n",
    "            # 3-channel images so each sample contributes one row of three pictures.\n",
    "            final_imgs = torch.cat([cond, img], dim=1) #(b, 9, 32, 32)   #user\n",
    "            final_imgs = final_imgs.reshape(-1, 3, 32, 32).contiguous()\n",
    "            save_image(final_imgs, os.path.join(params.samdir, f'generated_{epc+1}_pict.png'), nrow = 3)\n",
    "            print('Image saved as ',os.path.join(params.samdir, f'generated_{epc+1}_pict.png'))\n",
    "\n",
    "\n",
    "        # save checkpoints\n",
    "        checkpoint = {\n",
    "                    'net':diffusion.model.state_dict(),\n",
    "                    'cemblayer':cemblayer.state_dict(),\n",
    "                    'optimizer':optimizer.state_dict(),\n",
    "                    'scheduler':warmUpScheduler.state_dict()\n",
    "                }\n",
    "        # Write the checkpoint first and the 'last_epoch' pointer second, so an interrupt between\n",
    "        # the two saves never leaves the pointer naming a checkpoint file that does not exist.\n",
    "        torch.save(checkpoint, os.path.join(params.moddir, f'ckpt_{epc+1}_checkpoint.pt'))\n",
    "        torch.save({'last_epoch':epc+1}, os.path.join(params.moddir,'last_epoch.pt'))\n",
    "        print('Model saved as ',os.path.join(params.moddir, f'ckpt_{epc+1}_checkpoint.pt'))\n",
    "\n",
    "\n",
    "    torch.cuda.empty_cache()\n"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
