{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": 1,
            "id": "9007660b",
            "metadata": {},
            "outputs": [],
            "source": [
                "# ---------------------------------------------------------------\n",
                "# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.\n",
                "#\n",
                "# This work is licensed under the NVIDIA Source Code License\n",
                "# for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.\n",
                "# ---------------------------------------------------------------\n",
                "\n",
                "\n",
                "import argparse\n",
                "import torch\n",
                "import numpy as np\n",
                "import matplotlib.pyplot as plt\n",
                "import copy\n",
                "\n",
                "import os\n",
                "\n",
                "import torch.nn as nn\n",
                "import torch.nn.functional as F\n",
                "import torch.optim as optim\n",
                "import torchvision\n",
                "from torchmetrics.image.fid import FrechetInceptionDistance\n",
                "\n",
                "import torchvision.transforms as transforms\n",
                "from torchvision.datasets import CIFAR10\n",
                "from torchvision import datasets\n",
                "from datasets_prep.lsun import LSUN\n",
                "from datasets_prep.stackmnist_data import StackedMNIST, _data_transforms_stacked_mnist\n",
                "from datasets_prep.lmdb_datasets import LMDBDataset\n",
                "from torch.utils.data import TensorDataset\n",
                "from discrete_ot import OTPlanSampler\n",
                "\n",
                "\n",
                "from torch.multiprocessing import Process\n",
                "import torch.distributed as dist\n",
                "import shutil\n",
                "import pdb\n",
                "\n",
                "import ssl\n",
                "ssl._create_default_https_context = ssl._create_unverified_context\n",
                "\n",
                "import wandb\n",
                "\n",
                "from torch_ema import ExponentialMovingAverage\n",
                "\n",
                "from datetime import datetime\n",
                "datetime_marker_str = datetime.now().strftime(\"%d:%m:%y_%H:%M:%S\")"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "737d0314",
            "metadata": {},
            "source": [
                "## Parameters"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "b329b4c7",
            "metadata": {
                "tags": [
                    "parameters"
                ]
            },
            "outputs": [],
            "source": [
                "\n",
                "# This cell carries the `parameters` tag (presumably so a runner such as\n",
                "# papermill can override these defaults -- confirm against the launch scripts).\n",
                "\n",
                "fid_n_samples = 100\n",
                "eps = 1\n",
                "\n",
                "T = 4\n",
                "exp_name = 'BMGAN_ColoredMNIST'\n",
                "plan = 'ind'\n",
                "\n",
                "D_opt_steps = 2\n",
                "ema_decay = 0.99\n",
                "\n",
                "ipmf_iters = 20\n",
                "markovian_proj_iters = 30000\n",
                "\n",
                "inner_ipmf_mark_proj_iters = 10000\n",
                "\n",
                "# Integer switches (0 = off, > 0 = on); the next cell turns them into booleans.\n",
                "# `mini_batch_OT` used to be assigned twice in this cell (False, then 0);\n",
                "# only the effective integer default is kept.\n",
                "mini_batch_OT = 0\n",
                "ema_start_ipmf = 0\n",
                "\n",
                "device = 'cuda:0'\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "bd87e646",
            "metadata": {},
            "outputs": [],
            "source": [
                "# Collapse the integer switches from the parameters cell into real booleans:\n",
                "# any value greater than zero means \"enabled\".\n",
                "mini_batch_OT = bool(mini_batch_OT > 0)\n",
                "ema_start_ipmf = bool(ema_start_ipmf > 0)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "ac428edc",
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# Hyper-parameters that identify this run. Key order is kept as-is because the\n",
                "# dictionary is iterated when the experiment directory name is assembled.\n",
                "config = {\n",
                "    'eps': eps,\n",
                "    'T': T,\n",
                "    'mini_batch_OT': mini_batch_OT,\n",
                "    'D_opt_steps': D_opt_steps,\n",
                "    'ema_decay': ema_decay,\n",
                "    'ipmf_iters': ipmf_iters,\n",
                "    'markovian_proj_iters': markovian_proj_iters,\n",
                "    'inner_ipmf_mark_proj_iters': inner_ipmf_mark_proj_iters,\n",
                "    'ema_start_ipmf': ema_start_ipmf,\n",
                "    'plan': plan,\n",
                "}\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "a04232f8",
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "def config_to_expdir(config):\n",
                "    \"\"\"Build the run directory name from the timestamp and the config.\n",
                "\n",
                "    The name is `datetime_marker_str` immediately followed by one\n",
                "    `key=value` fragment per config entry, in iteration order; every\n",
                "    fragment is terminated by a single backslash character.\n",
                "\n",
                "    :param config: mapping of hyper-parameter names to values\n",
                "    :return: the assembled directory name as a string\n",
                "    \"\"\"\n",
                "    fragments = [str(key) + '=' + str(value) + '\\\\' for key, value in config.items()]\n",
                "    return datetime_marker_str + ''.join(fragments)\n",
                "\n",
                "save_dir = config_to_expdir(config)\n",
                "\n",
                "# exist_ok=True keeps this cell re-runnable: datetime_marker_str is computed\n",
                "# once in the import cell, so a second execution targets the same directory\n",
                "# and would otherwise raise FileExistsError.\n",
                "os.makedirs(os.path.join(exp_name, save_dir), exist_ok=True)\n"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "a3617af0",
            "metadata": {},
            "source": [
                "## Helper functions and posterior sampling\n",
                "\n",
                "Includes sampling from $p(x_t \\mid x_0, x_1)$ of the Brownian bridge."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "44f6fda5",
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# FID calculation\n",
                "def normalize_tensor(tensor):\n",
                "    \"\"\"Rescale with x / 2 + 0.5 (so -1 -> 0 and 1 -> 1) and clamp to [0, 1].\n",
                "\n",
                "    The input is left untouched; the clamp is applied in place to the\n",
                "    freshly computed tensor.\n",
                "    \"\"\"\n",
                "    return (tensor / 2 + 0.5).clamp_(min=0, max=1)\n",
                "\n",
                "\n",
                "def to_uint8_tensor(tensor):\n",
                "    \"\"\"Convert a tensor to uint8 pixel values in [0, 255].\n",
                "\n",
                "    The input is first passed through normalize_tensor, then scaled by 255,\n",
                "    offset by 0.5 (so the integer cast rounds to nearest), clamped and cast.\n",
                "    \"\"\"\n",
                "    unit_range = normalize_tensor(tensor)\n",
                "    pixel_values = unit_range.mul(255).add_(0.5).clamp_(0, 255)\n",
                "    return pixel_values.to(torch.uint8)\n",
                "\n",
                "\n",
                "def compute_fid_and_ot_cost(true_dataloader, model_input_dataloader, sample_fn):\n",
                "    \"\"\"Compute the FID of mapped samples and the mean squared transport cost.\n",
                "\n",
                "    :param true_dataloader: iterable of batches; element 0 of each batch is fed\n",
                "        to the FID metric as real images\n",
                "    :param model_input_dataloader: iterable of batches; element 0 of each batch\n",
                "        (`y`) is moved to the global `device` and passed to `sample_fn`\n",
                "    :param sample_fn: callable mapping an input batch to generated images\n",
                "    :return: `(fid, ot_cost)` where `ot_cost` is the batch-size weighted\n",
                "        average of `F.mse_loss(sample_fn(y), y)` over all processed inputs\n",
                "        (left at 0 when the input loader is empty)\n",
                "    \"\"\"\n",
                "    # backward loader y -> x\n",
                "    fid = FrechetInceptionDistance().to(device)\n",
                "    ot_cost = 0\n",
                "    n_inputs = 0\n",
                "\n",
                "    # Evaluation only: without no_grad the accumulated cost would keep the\n",
                "    # autograd graph of every batch alive.\n",
                "    with torch.no_grad():\n",
                "        for item in true_dataloader:\n",
                "            x = item[0]\n",
                "            # expand(-1, 3, -1, -1) assumes 4-D batches with 1 or 3 channels.\n",
                "            fid.update(to_uint8_tensor(x.expand(-1, 3, -1, -1)).to(device), real=True)\n",
                "\n",
                "        for item in model_input_dataloader:\n",
                "            y = item[0].to(device)\n",
                "            fake_sample = sample_fn(y).to(device)\n",
                "            fid.update(to_uint8_tensor(fake_sample.expand(-1, 3, -1, -1)), real=False)\n",
                "\n",
                "            ot_cost += F.mse_loss(fake_sample, y) * y.shape[0]\n",
                "            n_inputs += y.shape[0]\n",
                "\n",
                "    # Normalise by the number of inputs actually seen instead of a hard-coded\n",
                "    # 10000, which was only correct for a loader of exactly that size.\n",
                "    if n_inputs > 0:\n",
                "        ot_cost = ot_cost / n_inputs\n",
                "\n",
                "    return fid.compute(), ot_cost\n",
                "        "
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "3baed17f",
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "def copy_source(file, output_dir):\n",
                "    \"\"\"Copy `file` into `output_dir`, keeping its base name.\"\"\"\n",
                "    destination = os.path.join(output_dir, os.path.basename(file))\n",
                "    shutil.copyfile(file, destination)\n",
                "            \n",
                "def broadcast_params(params):\n",
                "    \"\"\"Broadcast the `.data` of every entry in `params` from rank 0.\"\"\"\n",
                "    for parameter in params:\n",
                "        dist.broadcast(parameter.data, src=0)\n",
                "\n",
                "\n",
                "#%% Diffusion coefficients \n",
                "def var_func_vp(t, beta_min, beta_max):\n",
                "    \"\"\"Return 1 - exp(2 * log_mean_coeff) for time tensor `t`.\n",
                "\n",
                "    log_mean_coeff = -0.25 * t^2 * (beta_max - beta_min) - 0.5 * t * beta_min\n",
                "    \"\"\"\n",
                "    beta_gap = beta_max - beta_min\n",
                "    log_mean_coeff = -0.25 * t ** 2 * beta_gap - 0.5 * t * beta_min\n",
                "    return 1. - torch.exp(2. * log_mean_coeff)\n",
                "\n",
                "def var_func_geometric(t, beta_min, beta_max):\n",
                "    \"\"\"Return beta_min * (beta_max / beta_min) ** t.\"\"\"\n",
                "    ratio = beta_max / beta_min\n",
                "    return beta_min * (ratio ** t)\n",
                "\n",
                "def extract(input, t, shape):\n",
                "    \"\"\"Gather `input` at indices `t` and reshape to (shape[0], 1, ..., 1).\n",
                "\n",
                "    The result has as many dimensions as `shape`, so it broadcasts against\n",
                "    a tensor of that shape.\n",
                "    \"\"\"\n",
                "    gathered = torch.gather(input, 0, t)\n",
                "    singleton_dims = (1,) * (len(shape) - 1)\n",
                "    return gathered.reshape(shape[0], *singleton_dims)\n",
                "\n",
                "def get_time_schedule(args, device):\n",
                "    \"\"\"Return num_timesteps + 1 evenly spaced times in [1e-3, 1] on `device`.\n",
                "\n",
                "    :param args: object exposing `num_timesteps`\n",
                "    :param device: target device for the returned float64 tensor\n",
                "    \"\"\"\n",
                "    n_timestep = args.num_timesteps\n",
                "    eps_small = 1e-3\n",
                "    unit_grid = np.arange(0, n_timestep + 1, dtype=np.float64) / n_timestep\n",
                "    # Affinely squeeze [0, 1] into [eps_small, 1].\n",
                "    times = torch.from_numpy(unit_grid) * (1. - eps_small) + eps_small\n",
                "    return times.to(device)\n",
                "\n",
                "def get_sigma_schedule(args, device):\n",
                "    \"\"\"Build the per-step noise schedule.\n",
                "\n",
                "    :param args: object exposing `num_timesteps`, `beta_min`, `beta_max`\n",
                "        and `use_geometric`\n",
                "    :param device: device the returned tensors live on\n",
                "    :return: `(sigmas, a_s, betas)`, float32 tensors with num_timesteps + 1\n",
                "        entries, where sigmas = sqrt(betas) and a_s = sqrt(1 - betas)\n",
                "    \"\"\"\n",
                "    n_timestep = args.num_timesteps\n",
                "    beta_min, beta_max = args.beta_min, args.beta_max\n",
                "    eps_small = 1e-3\n",
                "\n",
                "    unit_grid = np.arange(0, n_timestep + 1, dtype=np.float64) / n_timestep\n",
                "    t = torch.from_numpy(unit_grid) * (1. - eps_small) + eps_small\n",
                "\n",
                "    var_func = var_func_geometric if args.use_geometric else var_func_vp\n",
                "    var = var_func(t, beta_min, beta_max)\n",
                "    alpha_bars = 1.0 - var\n",
                "    betas = 1 - alpha_bars[1:] / alpha_bars[:-1]\n",
                "\n",
                "    # Prepend a tiny beta so the schedule has one entry per grid point.\n",
                "    first = torch.tensor(1e-8)\n",
                "    betas = torch.cat((first[None], betas)).to(device).type(torch.float32)\n",
                "\n",
                "    sigmas = betas**0.5\n",
                "    a_s = torch.sqrt(1-betas)\n",
                "    return sigmas, a_s, betas\n",
                "\n",
                "class Diffusion_Coefficients():\n",
                "    \"\"\"Forward-diffusion coefficients derived from get_sigma_schedule.\n",
                "\n",
                "    Attributes (all tensors on `device`):\n",
                "        sigmas, a_s: per-step values returned by get_sigma_schedule\n",
                "        a_s_cum: cumulative product of a_s\n",
                "        sigmas_cum: sqrt(1 - a_s_cum ** 2)\n",
                "        a_s_prev: copy of a_s with its last entry set to 1\n",
                "    \"\"\"\n",
                "    def __init__(self, args, device):\n",
                "        self.sigmas, self.a_s, _ = get_sigma_schedule(args, device=device)\n",
                "\n",
                "        # Use torch ops throughout: the NumPy versions only handed back tensors\n",
                "        # through implicit array wrapping, which the later .to(device) relied on.\n",
                "        # The cumulative quantities are still computed on CPU, then moved.\n",
                "        a_s_cum = torch.cumprod(self.a_s.cpu(), dim=0)\n",
                "        sigmas_cum = torch.sqrt(1 - a_s_cum ** 2)\n",
                "\n",
                "        a_s_prev = self.a_s.clone()\n",
                "        a_s_prev[-1] = 1\n",
                "\n",
                "        self.a_s_cum = a_s_cum.to(device)\n",
                "        self.sigmas_cum = sigmas_cum.to(device)\n",
                "        self.a_s_prev = a_s_prev.to(device)\n",
                "    \n",
                "def q_sample(coeff, x_start, t, *, noise=None):\n",
                "    \"\"\"\n",
                "    Diffuse x_start to step t: a_s_cum[t] * x_start + sigmas_cum[t] * noise\n",
                "    :param coeff: object exposing a_s_cum and sigmas_cum\n",
                "    :param x_start: x_0\n",
                "    :param t: tensor of time-step indices\n",
                "    :param noise: optional noise tensor; standard normal noise is drawn if None\n",
                "    :return: x_t\n",
                "    \"\"\"\n",
                "    if noise is None:\n",
                "        noise = torch.randn_like(x_start)\n",
                "\n",
                "    signal_scale = extract(coeff.a_s_cum, t, x_start.shape)\n",
                "    noise_scale = extract(coeff.sigmas_cum, t, x_start.shape)\n",
                "    return signal_scale * x_start + noise_scale * noise\n",
                "\n",
                "def q_sample_supervised(pos_coeff, x_start, t, x_end, *, noise=None):\n",
                "    \"\"\"\n",
                "    Step backwards from x_end down to step t[0] with posterior samples\n",
                "    conditioned on x_start.\n",
                "    :param pos_coeff: posterior coefficients consumed by sample_posterior\n",
                "    :param x_start: x_0\n",
                "    :param t: tensor of time steps; only t[0] is used, for the whole batch\n",
                "    :param x_end: sample the backward walk starts from\n",
                "    :param noise: unused; kept so existing keyword callers keep working\n",
                "    :return: the sample after the last posterior step (x_end if no step is taken)\n",
                "    \"\"\"\n",
                "    # Read the number of steps from pos_coeff, as q_sample_supervised_pairs\n",
                "    # does, instead of from a name `coeff` that is not passed to this function.\n",
                "    num_steps = pos_coeff.posterior_mean_coef1.shape[0]\n",
                "\n",
                "    x_t = x_end\n",
                "    for t_current in reversed(range(int(t[0]), num_steps)):\n",
                "        t_tensor = torch.full((x_t.size(0),), t_current, dtype=torch.int64).to(x_t.device)\n",
                "        x_t = sample_posterior(pos_coeff, x_start, x_t, t_tensor)\n",
                "\n",
                "    return x_t\n",
                "\n",
                "def q_sample_pairs(coeff, x_start, t):\n",
                "    \"\"\"\n",
                "    Generate a pair of consecutive diffused images for training\n",
                "    :param coeff: object exposing a_s, sigmas, a_s_cum and sigmas_cum\n",
                "    :param x_start: x_0\n",
                "    :param t: time step t\n",
                "    :return: x_t, x_{t+1}\n",
                "    \"\"\"\n",
                "    # Drawn before q_sample so the random stream is consumed in the same order.\n",
                "    step_noise = torch.randn_like(x_start)\n",
                "    x_t = q_sample(coeff, x_start, t)\n",
                "\n",
                "    step_scale = extract(coeff.a_s, t+1, x_start.shape)\n",
                "    step_sigma = extract(coeff.sigmas, t+1, x_start.shape)\n",
                "    x_t_plus_one = step_scale * x_t + step_sigma * step_noise\n",
                "\n",
                "    return x_t, x_t_plus_one\n",
                "\n",
                "def q_sample_supervised_pairs(pos_coeff, x_start, t, x_end):\n",
                "    \"\"\"\n",
                "    Generate a pair of consecutive samples for training by stepping backwards\n",
                "    from x_end with posterior samples conditioned on x_start\n",
                "    :param pos_coeff: posterior coefficients consumed by sample_posterior\n",
                "    :param x_start: x_0\n",
                "    :param t: tensor of time steps; only t[0] is used, for the whole batch\n",
                "    :param x_end: sample the backward walk starts from\n",
                "    :return: x_t, x_{t+1}\n",
                "    \"\"\"\n",
                "    # Number of posterior steps available.\n",
                "    T = pos_coeff.posterior_mean_coef1.shape[0]\n",
                "\n",
                "    x_t_plus_one = x_end\n",
                "    t_current = T\n",
                "\n",
                "    # Apply posterior steps with indices T-1, ..., t[0].\n",
                "    # The loop only stops on equality, so it never ends if t[0] > T.\n",
                "    while t_current != t[0]:\n",
                "        t_tensor = torch.full((x_end.size(0),), t_current-1, dtype=torch.int64).to(x_end.device)\n",
                "        x_t_plus_one = sample_posterior(pos_coeff, x_start, x_t_plus_one, t_tensor)\n",
                "        t_current -= 1\n",
                "\n",
                "    # NOTE(review): index t[0] was already applied by the last loop iteration and\n",
                "    # is applied once more here. q_sample_supervised_pairs_brownian instead places\n",
                "    # x_{t+1} at time (t+1)/num_steps -- confirm which indexing is intended.\n",
                "    t_tensor = torch.full((x_end.size(0),), t_current, dtype=torch.int64).to(x_end.device)\n",
                "    x_t = sample_posterior(pos_coeff, x_start, x_t_plus_one, t_tensor)\n",
                "    \n",
                "    return x_t, x_t_plus_one\n",
                "\n",
                "\n",
                "def q_sample_supervised_pairs_brownian(pos_coeff, x_start, t, x_end):\n",
                "    \"\"\"\n",
                "    Generate a pair of consecutive bridge samples for training\n",
                "    :param pos_coeff: object exposing epsilon and the posterior coefficients\n",
                "    :param x_start: x_0\n",
                "    :param t: tensor of time steps, one per batch element\n",
                "    :param x_end: bridge end point\n",
                "    :return: x_t, x_{t+1}\n",
                "    \"\"\"\n",
                "    noise = torch.randn_like(x_start)\n",
                "    num_steps = pos_coeff.posterior_mean_coef1.shape[0]\n",
                "\n",
                "    # Continuous time of step t+1, shaped to broadcast over 4-D batches.\n",
                "    s = ((t+1)/num_steps)[:, None, None, None]\n",
                "\n",
                "    # x_{t+1}: linear interpolation plus noise of variance epsilon * s * (1 - s).\n",
                "    bridge_mean = s*x_end + (1.0 - s)*x_start\n",
                "    bridge_std = torch.sqrt(pos_coeff.epsilon*s*(1-s))\n",
                "    x_t_plus_one = bridge_mean + bridge_std*noise\n",
                "\n",
                "    x_t = sample_posterior(pos_coeff, x_start, x_t_plus_one, t)\n",
                "\n",
                "    return x_t, x_t_plus_one\n",
                "\n",
                "\n",
                "def q_sample_supervised_trajectory(pos_coeff, x_start, x_end):\n",
                "    \"\"\"\n",
                "    Sample a full backward trajectory from x_end, conditioned on x_start\n",
                "    :param pos_coeff: posterior coefficients consumed by sample_posterior\n",
                "    :param x_start: x_0\n",
                "    :param x_end: first element of the trajectory\n",
                "    :return: list starting with x_end, followed by one sample per posterior\n",
                "        index T-1, ..., 0 and one further sample at index 0\n",
                "    \"\"\"\n",
                "    num_steps = pos_coeff.posterior_mean_coef1.shape[0]\n",
                "    batch_size = x_end.size(0)\n",
                "\n",
                "    trajectory = [x_end]\n",
                "    current = x_end\n",
                "\n",
                "    for step in range(num_steps, 0, -1):\n",
                "        t_tensor = torch.full((batch_size,), step-1, dtype=torch.int64).to(x_end.device)\n",
                "        current = sample_posterior(pos_coeff, x_start, current, t_tensor)\n",
                "        trajectory.append(current)\n",
                "\n",
                "    # One more posterior draw at index 0, appended as the final element.\n",
                "    t_tensor = torch.full((batch_size,), 0, dtype=torch.int64).to(x_end.device)\n",
                "    trajectory.append(sample_posterior(pos_coeff, x_start, current, t_tensor))\n",
                "\n",
                "    return trajectory\n",
                "\n",
                "#%% posterior sampling\n",
                "class Posterior_Coefficients():\n",
                "    \"\"\"Posterior coefficient tables derived from get_sigma_schedule.\n",
                "\n",
                "    Every attribute is a tensor with one entry per time step (the leading\n",
                "    entry of the schedule's betas is dropped).\n",
                "    \"\"\"\n",
                "    def __init__(self, args, device):\n",
                "        _, _, schedule_betas = get_sigma_schedule(args, device=device)\n",
                "\n",
                "        # Skip the tiny placeholder beta that the schedule puts at index 0.\n",
                "        self.betas = schedule_betas.type(torch.float32)[1:]\n",
                "\n",
                "        self.alphas = 1 - self.betas\n",
                "        self.alphas_cumprod = torch.cumprod(self.alphas, 0)\n",
                "        leading_one = torch.tensor([1.], dtype=torch.float32, device=device)\n",
                "        self.alphas_cumprod_prev = torch.cat((leading_one, self.alphas_cumprod[:-1]), 0)\n",
                "\n",
                "        one_minus_cumprod = 1 - self.alphas_cumprod\n",
                "        one_minus_cumprod_prev = 1 - self.alphas_cumprod_prev\n",
                "\n",
                "        self.posterior_variance = self.betas * one_minus_cumprod_prev / one_minus_cumprod\n",
                "\n",
                "        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)\n",
                "        self.sqrt_recip_alphas_cumprod = torch.rsqrt(self.alphas_cumprod)\n",
                "        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod - 1)\n",
                "\n",
                "        # The posterior mean is coef1 * x_0 + coef2 * x_t (see sample_posterior).\n",
                "        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / one_minus_cumprod\n",
                "        self.posterior_mean_coef2 = one_minus_cumprod_prev * torch.sqrt(self.alphas) / one_minus_cumprod\n",
                "\n",
                "        # Clamp before the log so a zero variance does not produce -inf.\n",
                "        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))\n",
                "\n",
                "\n",
                "class BrownianPosterior_Coefficients():\n",
                "    \"\"\"Posterior coefficient tables on a uniform time grid over [0, 1].\n",
                "\n",
                "    With t_prev = grid[:-1] and t_next = grid[1:]:\n",
                "        posterior_mean_coef1 = 1 - t_prev / t_next\n",
                "        posterior_mean_coef2 = t_prev / t_next\n",
                "        posterior_variance   = epsilon * t_prev * (t_next - t_prev) / t_next\n",
                "    \"\"\"\n",
                "    def __init__(self, args, device):\n",
                "        self.epsilon = args.epsilon\n",
                "\n",
                "        grid = torch.linspace(0, 1, args.num_timesteps + 1, device=device)\n",
                "        t_prev, t_next = grid[:-1], grid[1:]\n",
                "        ratio = t_prev / t_next\n",
                "\n",
                "        self.posterior_mean_coef1 = 1 - ratio\n",
                "        self.posterior_mean_coef2 = ratio\n",
                "\n",
                "        self.posterior_variance = self.epsilon * t_prev * (t_next - t_prev) / t_next\n",
                "        # The first variance is exactly 0, so clamp before taking the log.\n",
                "        self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))\n",
                "\n",
                "        \n",
                "def sample_posterior(coefficients, x_0, x_t, t):\n",
                "    \"\"\"Draw one sample from the Gaussian posterior described by `coefficients`.\n",
                "\n",
                "    :param coefficients: object exposing posterior_mean_coef1,\n",
                "        posterior_mean_coef2 and posterior_log_variance_clipped\n",
                "    :param x_0: tensor weighted by posterior_mean_coef1\n",
                "    :param x_t: tensor weighted by posterior_mean_coef2\n",
                "    :param t: int64 tensor of step indices, one per batch element\n",
                "    :return: mean + std * noise, with the noise term zeroed where t == 0\n",
                "    \"\"\"\n",
                "    coef_x0 = extract(coefficients.posterior_mean_coef1, t, x_t.shape)\n",
                "    coef_xt = extract(coefficients.posterior_mean_coef2, t, x_t.shape)\n",
                "    mean = coef_x0 * x_0 + coef_xt * x_t\n",
                "    log_var = extract(coefficients.posterior_log_variance_clipped, t, x_t.shape)\n",
                "\n",
                "    noise = torch.randn_like(x_t)\n",
                "\n",
                "    # 0 where t == 0 and 1 elsewhere, so the last step is deterministic.\n",
                "    nonzero_mask = (1 - (t == 0).type(torch.float32))\n",
                "\n",
                "    return mean + nonzero_mask[:,None,None,None] * torch.exp(0.5 * log_var) * noise\n",
                "\n",
                "def sample_from_model(coefficients, generator, n_time, x_init, T, opt, return_trajectory=False):\n",
                "    \"\"\"Run the generator backwards for `n_time` steps, starting from `x_init`.\n",
                "\n",
                "    :param coefficients: posterior coefficients passed to sample_posterior\n",
                "    :param generator: callable `(x, t, latent_z) -> predicted x_0`\n",
                "    :param n_time: number of steps; indices n_time - 1, ..., 0 are visited\n",
                "    :param x_init: starting sample\n",
                "    :param T: not used by this function\n",
                "    :param opt: object exposing `nz`, the latent dimension\n",
                "    :param return_trajectory: if true, also return predictions and states\n",
                "    :return: the final sample, or `(final, x_0_pred_list, trajectory)`\n",
                "    \"\"\"\n",
                "    x = x_init\n",
                "    trajectory = [x]\n",
                "    x_0_pred_list = []\n",
                "\n",
                "    with torch.no_grad():\n",
                "        for step in range(n_time - 1, -1, -1):\n",
                "            batch_size = x.size(0)\n",
                "            t = torch.full((batch_size,), step, dtype=torch.int64).to(x.device)\n",
                "            latent_z = torch.randn(batch_size, opt.nz, device=x.device)\n",
                "\n",
                "            x_0 = generator(x, t, latent_z)\n",
                "            x = sample_posterior(coefficients, x_0, x, t).detach()\n",
                "\n",
                "            x_0_pred_list.append(x_0.detach())\n",
                "            trajectory.append(x)\n",
                "\n",
                "    if not return_trajectory:\n",
                "        return x\n",
                "\n",
                "    return x, x_0_pred_list, trajectory\n",
                "\n",
                "\n",
                "def get_random_colored_images(images, seed=None):\n",
                "    \"\"\"Colorize each image with a random hue.\n",
                "\n",
                "    :param images: batch indexed as (N, 1, H, W) with values on the [-1, 1]\n",
                "        scale (rescaled to [0, 1] internally)\n",
                "    :param seed: if given, seeds NumPy's global RNG before the hues are drawn\n",
                "    :return: tensor of shape (N, 3, H, W), rescaled back with 2 * x - 1\n",
                "    \"\"\"\n",
                "    if seed is not None:\n",
                "        np.random.seed(seed)\n",
                "\n",
                "    images = 0.5*(images + 1)\n",
                "    hues = 360*np.random.rand(images.shape[0])\n",
                "    V_min = 0\n",
                "\n",
                "    colored_images = []\n",
                "    for V, H in zip(images, hues):\n",
                "        V_inc = (V - V_min)*(H%60)/60\n",
                "        V_dec = V - V_inc\n",
                "\n",
                "        # (R, G, B) content for each of the six hue sectors.\n",
                "        channels_by_sector = (\n",
                "            (V, V_inc, V_min),\n",
                "            (V_dec, V, V_min),\n",
                "            (V_min, V, V_inc),\n",
                "            (V_min, V_dec, V),\n",
                "            (V_inc, V_min, V),\n",
                "            (V, V_min, V_dec),\n",
                "        )\n",
                "        # The sector is picked with round(), exactly as before.\n",
                "        sector = round(H/60) % 6\n",
                "\n",
                "        colored_image = torch.zeros((3, V.shape[1], V.shape[2]))\n",
                "        for channel, value in enumerate(channels_by_sector[sector]):\n",
                "            colored_image[channel] = value\n",
                "\n",
                "        colored_images.append(colored_image)\n",
                "\n",
                "    colored_images = torch.stack(colored_images, dim = 0)\n",
                "    return 2*colored_images - 1\n",
                "\n",
                "from torchvision import datasets\n",
                "from torch.utils.data import TensorDataset\n",
                "\n",
                "def load_paired_colored_mnist_one_side(target_number=2, train=True, seed=None, dataset_size=None):\n",
                "    \"\"\"Load one MNIST digit class, colorize it and pair it with zeros.\n",
                "\n",
                "    :param target_number: digit label to keep\n",
                "    :param train: whether to read the MNIST train split\n",
                "    :param seed: forwarded to get_random_colored_images\n",
                "    :param dataset_size: if given and larger than the number of digits found,\n",
                "        the digits are repeated and cut to exactly this size\n",
                "    :return: TensorDataset of (colored digits, zeros of the same shape)\n",
                "    \"\"\"\n",
                "    to_centered_tensor = torchvision.transforms.Compose([\n",
                "        torchvision.transforms.Resize((32, 32)),\n",
                "        torchvision.transforms.ToTensor(),\n",
                "        torchvision.transforms.Lambda(lambda x: 2 * x - 1)\n",
                "    ])\n",
                "\n",
                "    mnist = datasets.MNIST(\"./\", train=train, transform=to_centered_tensor, download=True)\n",
                "\n",
                "    matching = [mnist[i][0] for i in range(len(mnist.targets)) if mnist.targets[i] == target_number]\n",
                "    digits = torch.stack(matching, dim=0).reshape(-1, 1, 32, 32)\n",
                "\n",
                "    # Pad by repetition only; a larger-than-requested set is left as it is.\n",
                "    if dataset_size is not None and digits.shape[0] < dataset_size:\n",
                "        repeats = dataset_size // digits.shape[0] + 1\n",
                "        digits = digits.repeat([repeats, 1, 1, 1])[:dataset_size]\n",
                "\n",
                "    digits_colored = get_random_colored_images(digits, seed=seed)\n",
                "\n",
                "    return TensorDataset(digits_colored, torch.zeros_like(digits_colored))\n",
                "\n",
                "def load_paired_colored_mnist():\n",
                "    \"\"\"Return the paired dataset of randomly colored MNIST 2s and 3s.\n",
                "\n",
                "    This is load_prior_paired_colored_mnist for digits 2 and 3 without an\n",
                "    extra transform; delegating removes a duplicate of the same loading code\n",
                "    (and the unused `x`, `y` lists it carried).\n",
                "\n",
                "    :return: TensorDataset of (colored 2s, colored 3s), both cut to the size\n",
                "        of the smaller set\n",
                "    \"\"\"\n",
                "    return load_prior_paired_colored_mnist(num_0=2, num_1=3, transform=None)\n",
                "    \n",
                "def load_prior_paired_colored_mnist(num_0=2, num_1=3, transform=None):\n",
                "    \"\"\"Build a paired dataset of two randomly colored MNIST digit classes.\n",
                "\n",
                "    :param num_0: digit label of the first element of each pair\n",
                "    :param num_1: digit label of the second element of each pair\n",
                "    :param transform: optional callable taking both colored batches and\n",
                "        returning the two batches to pair\n",
                "    :return: TensorDataset of the two batches, cut to the smaller size\n",
                "    \"\"\"\n",
                "    transform_ = torchvision.transforms.Compose([\n",
                "        torchvision.transforms.Resize((32, 32)),\n",
                "        torchvision.transforms.ToTensor(),\n",
                "        torchvision.transforms.Lambda(lambda x: 2 * x - 1)\n",
                "    ])\n",
                "\n",
                "    train_set = datasets.MNIST(\"./\", train=True, transform=transform_, download=True)\n",
                "    # Never read below; still constructed (with download=True) as before.\n",
                "    test_set = datasets.MNIST(\"./\", train=False, transform=transform_, download=True)\n",
                "\n",
                "    def colored_digits(label):\n",
                "        # Collect the train images with this label and colorize them.\n",
                "        picked = [train_set[i][0] for i in range(len(train_set.targets)) if train_set.targets[i] == label]\n",
                "        return get_random_colored_images(torch.stack(picked, dim=0).reshape(-1, 1, 32, 32))\n",
                "\n",
                "    digits_0_colored = colored_digits(num_0)\n",
                "    digits_1_colored = colored_digits(num_1)\n",
                "\n",
                "    if transform is not None:\n",
                "        digits_0_colored, digits_1_colored = transform(digits_0_colored, digits_1_colored)\n",
                "\n",
                "    size = min(digits_0_colored.shape[0], digits_1_colored.shape[0])\n",
                "\n",
                "    return TensorDataset(digits_0_colored[:size], digits_1_colored[:size])"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "b3df853c",
            "metadata": {},
            "source": [
                "## Loading config with DDGAN base parameters"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "04060075",
            "metadata": {},
            "outputs": [],
            "source": [
                "parser = argparse.ArgumentParser('ddgan parameters')\n",
                "parser.add_argument('--seed', type=int, default=1024,\n",
                "                    help='seed used for initialization')\n",
                "\n",
                "parser.add_argument('--plan', type=str, default='ind',\n",
                "                    help='Init type')\n",
                "\n",
                "parser.add_argument('--resume', action='store_true',default=False)\n",
                "\n",
                "parser.add_argument('--image_size', type=int, default=32,\n",
                "                        help='size of image')\n",
                "parser.add_argument('--num_channels', type=int, default=3,\n",
                "                        help='channel of image')\n",
                "parser.add_argument('--centered', action='store_false', default=True,\n",
                "                        help='-1,1 scale')\n",
                "\n",
                "parser.add_argument('--posterior', type=str, default='ddpm',\n",
                "                    help='type of posterior to use')\n",
                "\n",
                "# ddpm prior\n",
                "parser.add_argument('--use_geometric', action='store_true',default=False)\n",
                "parser.add_argument('--beta_min', type=float, default= 0.1,\n",
                "                        help='beta_min for diffusion')\n",
                "parser.add_argument('--beta_max', type=float, default=20.,\n",
                "                        help='beta_max for diffusion')\n",
                "\n",
                "# brownian bridge prior\n",
                "parser.add_argument('--epsilon', type=float, default=1.0,\n",
                "                        help='variance of brownian bridge')\n",
                "\n",
                "parser.add_argument('--num_channels_dae', type=int, default=128,\n",
                "                        help='number of initial channels in denosing model')\n",
                "parser.add_argument('--n_mlp', type=int, default=3,\n",
                "                        help='number of mlp layers for z')\n",
                "parser.add_argument('--ch_mult', nargs='+', type=int,\n",
                "                        help='channel multiplier')\n",
                "parser.add_argument('--num_res_blocks', type=int, default=2,\n",
                "                        help='number of resnet blocks per scale')\n",
                "parser.add_argument('--attn_resolutions', default=(16,),\n",
                "                        help='resolution of applying attention')\n",
                "parser.add_argument('--dropout', type=float, default=0.,\n",
                "                        help='drop-out rate')\n",
                "parser.add_argument('--resamp_with_conv', action='store_false', default=True,\n",
                "                        help='always up/down sampling with conv')\n",
                "parser.add_argument('--conditional', action='store_false', default=True,\n",
                "                        help='noise conditional')\n",
                "parser.add_argument('--fir', action='store_false', default=True,\n",
                "                        help='FIR')\n",
                "parser.add_argument('--fir_kernel', default=[1, 3, 3, 1],\n",
                "                        help='FIR kernel')\n",
                "parser.add_argument('--skip_rescale', action='store_false', default=True,\n",
                "                        help='skip rescale')\n",
                "parser.add_argument('--resblock_type', default='biggan',\n",
                "                        help='tyle of resnet block, choice in biggan and ddpm')\n",
                "parser.add_argument('--progressive', type=str, default='none', choices=['none', 'output_skip', 'residual'],\n",
                "                        help='progressive type for output')\n",
                "parser.add_argument('--progressive_input', type=str, default='residual', choices=['none', 'input_skip', 'residual'],\n",
                "                    help='progressive type for input')\n",
                "parser.add_argument('--progressive_combine', type=str, default='sum', choices=['sum', 'cat'],\n",
                "                    help='progressive combine method.')\n",
                "\n",
                "parser.add_argument('--embedding_type', type=str, default='positional', choices=['positional', 'fourier'],\n",
                "                    help='type of time embedding')\n",
                "parser.add_argument('--fourier_scale', type=float, default=16.,\n",
                "                        help='scale of fourier transform')\n",
                "parser.add_argument('--not_use_tanh', action='store_true',default=False)\n",
                "\n",
                "#geenrator and training\n",
                "parser.add_argument('--exp', default='experiment_cifar_default', help='name of experiment')\n",
                "parser.add_argument('--dataset', default='cifar10', help='name of dataset')\n",
                "parser.add_argument('--nz', type=int, default=100)\n",
                "parser.add_argument('--num_timesteps', type=int, default=4)\n",
                "\n",
                "parser.add_argument('--z_emb_dim', type=int, default=256)\n",
                "parser.add_argument('--t_emb_dim', type=int, default=256)\n",
                "parser.add_argument('--batch_size', type=int, default=128, help='input batch size')\n",
                "parser.add_argument('--num_epoch', type=int, default=1200)\n",
                "parser.add_argument('--ngf', type=int, default=64)\n",
                "\n",
                "parser.add_argument('--lr_g', type=float, default=1.5e-4, help='learning rate g')\n",
                "parser.add_argument('--lr_d', type=float, default=1e-4, help='learning rate d')\n",
                "parser.add_argument('--beta1', type=float, default=0.5,\n",
                "                        help='beta1 for adam')\n",
                "parser.add_argument('--beta2', type=float, default=0.9,\n",
                "                        help='beta2 for adam')\n",
                "parser.add_argument('--no_lr_decay',action='store_true', default=False)\n",
                "\n",
                "parser.add_argument('--use_ema', action='store_true', default=False,\n",
                "                        help='use EMA or not')\n",
                "parser.add_argument('--ema_decay', type=float, default=0.9999, help='decay rate for EMA')\n",
                "\n",
                "parser.add_argument('--r1_gamma', type=float, default=0.05, help='coef for r1 reg')\n",
                "parser.add_argument('--lazy_reg', type=int, default=None,\n",
                "                    help='lazy regulariation.')\n",
                "\n",
                "parser.add_argument('--save_content', action='store_true',default=False)\n",
                "parser.add_argument('--save_content_every', type=int, default=50, help='save content for resuming every x epochs')\n",
                "parser.add_argument('--save_ckpt_every', type=int, default=25, help='save ckpt every x epochs')\n",
                "\n",
                "###ddp\n",
                "parser.add_argument('--num_proc_node', type=int, default=1,\n",
                "                    help='The number of nodes in multi node env.')\n",
                "parser.add_argument('--num_process_per_node', type=int, default=1,\n",
                "                    help='number of gpus')\n",
                "parser.add_argument('--node_rank', type=int, default=0,\n",
                "                    help='The index of node.')\n",
                "parser.add_argument('--local_rank', type=int, default=0,\n",
                "                    help='rank of process in the node')\n",
                "parser.add_argument('--master_address', type=str, default='127.0.0.1',\n",
                "                    help='address for master')\n",
                "\n",
                "# Emulate a command line for the paired colored-MNIST experiment. `plan`, `T` and `eps` are\n",
                "# notebook-level variables that must already be defined when this cell runs.\n",
                "cli_params = f'--plan {plan} --dataset paired_colored_mnist --num_timesteps {T} --exp ddgan_colored_mnist --num_channels 3 --num_channels_dae 128 --num_res_blocks 2 --batch_size 64 --num_epoch 1800 --ngf 64 --nz 100 --z_emb_dim 256 --n_mlp 4 --embedding_type positional --r1_gamma 0.02 --lr_d 1.25e-4 --lr_g 1.6e-4 --lazy_reg 15 --num_process_per_node 1 --ch_mult 1 2 2 2 --save_content --posterior brownian_bridge --epsilon {eps}'\n",
                "\n",
                "# split() with no argument collapses repeated/leading/trailing whitespace, so a stray space in the\n",
                "# template or in an interpolated value cannot produce empty tokens that argparse would reject.\n",
                "args = parser.parse_args(cli_params.split())\n",
                "print(args.plan)\n",
                "\n",
                "# Total process count across nodes; `size` is the per-node process count.\n",
                "args.world_size = args.num_proc_node * args.num_process_per_node\n",
                "size = args.num_process_per_node\n",
                "\n",
                "# EMA is controlled by the notebook-level `ema_decay` rather than by --use_ema: a positive decay\n",
                "# switches EMA on, and the decay value always overrides the parsed default.\n",
                "if ema_decay > 0:\n",
                "    args.use_ema = True\n",
                "args.ema_decay = ema_decay\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "4d611de1",
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "# train(rank=0, gpu=0, args=args)\n",
                "\n",
                "rank = 0\n",
                "gpu = 0\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "2ac33b69",
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "from score_sde.models.discriminator import Discriminator_small, Discriminator_large\n",
                "from score_sde.models.ncsnpp_generator_adagn import NCSNpp\n",
                "\n",
                "# Seed the torch RNGs (CPU, current CUDA device, all CUDA devices) with a rank-dependent offset.\n",
                "torch.manual_seed(args.seed + rank)\n",
                "torch.cuda.manual_seed(args.seed + rank)\n",
                "torch.cuda.manual_seed_all(args.seed + rank)\n",
                "device = torch.device('cuda:{}'.format(gpu))\n",
                "\n",
                "batch_size = args.batch_size\n",
                "\n",
                "nz = args.nz #latent dimension\n",
                "\n",
                "# Pick the training dataset by name.\n",
                "if args.dataset == 'cifar10':\n",
                "    dataset = CIFAR10('./data', train=True, transform=transforms.Compose([\n",
                "                    transforms.Resize(32),\n",
                "                    transforms.RandomHorizontalFlip(),\n",
                "                    transforms.ToTensor(),\n",
                "                    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]), download=True)\n",
                "\n",
                "\n",
                "elif args.dataset == 'stackmnist':\n",
                "    train_transform, valid_transform = _data_transforms_stacked_mnist()\n",
                "    dataset = StackedMNIST(root='./data', train=True, download=False, transform=train_transform)\n",
                "\n",
                "elif args.dataset == 'lsun':\n",
                "\n",
                "    train_transform = transforms.Compose([\n",
                "                    transforms.Resize(args.image_size),\n",
                "                    transforms.CenterCrop(args.image_size),\n",
                "                    transforms.RandomHorizontalFlip(),\n",
                "                    transforms.ToTensor(),\n",
                "                    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))\n",
                "                ])\n",
                "\n",
                "    train_data = LSUN(root='/datasets/LSUN/', classes=['church_outdoor_train'], transform=train_transform)\n",
                "    subset = list(range(0, 120000))\n",
                "    dataset = torch.utils.data.Subset(train_data, subset)\n",
                "\n",
                "\n",
                "elif args.dataset == 'celeba_256':\n",
                "    train_transform = transforms.Compose([\n",
                "            transforms.Resize(args.image_size),\n",
                "            transforms.RandomHorizontalFlip(),\n",
                "            transforms.ToTensor(),\n",
                "            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))\n",
                "        ])\n",
                "    dataset = LMDBDataset(root='/datasets/celeba-lmdb/', name='celeba', train=True, transform=train_transform)\n",
                "\n",
                "elif args.dataset == 'paired_colored_mnist':\n",
                "    # insert my dataset\n",
                "    dataset = load_paired_colored_mnist()\n",
                "\n",
                "else:\n",
                "    # Fail fast: without this branch an unknown name leaves `dataset` undefined, or silently reuses a\n",
                "    # stale `dataset` left in the kernel by an earlier run of this cell.\n",
                "    raise ValueError(f'Unsupported dataset: {args.dataset!r}')\n",
                "\n",
                "\n",
                "# The sampler decides the iteration order, hence shuffle=False on the loader itself.\n",
                "train_sampler = torch.utils.data.distributed.DistributedSampler(dataset,\n",
                "                                                                num_replicas=args.world_size,\n",
                "                                                                rank=rank)\n",
                "data_loader = torch.utils.data.DataLoader(dataset,\n",
                "                                           batch_size=batch_size,\n",
                "                                           shuffle=False,\n",
                "                                           num_workers=4,\n",
                "                                           pin_memory=True,\n",
                "                                           sampler=train_sampler,\n",
                "                                           drop_last = True)\n",
                "def inverse_color(dataset_0, dataset_1):\n",
                "    return dataset_0, -dataset_1\n",
                "\n",
                "# Build the prior (initial coupling) dataset; the 'inv_color' plan negates the second image set.\n",
                "if args.plan == 'inv_color':\n",
                "    prior_dataset = load_prior_paired_colored_mnist(7, 3, inverse_color)\n",
                "else:\n",
                "    prior_dataset = load_prior_paired_colored_mnist(2, 3)\n",
                "\n",
                "# The prior dataset needs its own sampler: a sampler emits indices in range(len(<its dataset>)),\n",
                "# so reusing `train_sampler` (built on `dataset`) would index out of range or skip samples\n",
                "# whenever the two datasets differ in length.\n",
                "prior_sampler = torch.utils.data.distributed.DistributedSampler(prior_dataset,\n",
                "                                                                num_replicas=args.world_size,\n",
                "                                                                rank=rank)\n",
                "prior_data_loader = torch.utils.data.DataLoader(prior_dataset,\n",
                "                                           batch_size=batch_size,\n",
                "                                           shuffle=False,\n",
                "                                           num_workers=4,\n",
                "                                           pin_memory=True,\n",
                "                                           sampler=prior_sampler,\n",
                "                                           drop_last = True)\n",
                "\n",
                "# Instantiate the generator from the parsed config and move it to the selected device.\n",
                "netG = NCSNpp(args).to(device)\n",
                "\n",
                "\n",
                "# ddp (disabled in the notebook)\n",
                "# netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])\n",
                "# netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])\n",
                "\n",
                "\n",
                "# Experiment output directory: ./saved_info/dd_gan/<dataset>/<exp>\n",
                "exp = args.exp\n",
                "parent_dir = \"./saved_info/dd_gan/{}\".format(args.dataset)\n",
                "\n",
                "exp_path = os.path.join(parent_dir,exp)\n",
                "if rank == 0:\n",
                "    # Only on first creation of the directory: snapshot code next to the results.\n",
                "    # NOTE(review): copy_source is defined elsewhere; presumably it copies the 'BMGAN' source - confirm.\n",
                "    if not os.path.exists(exp_path):\n",
                "        os.makedirs(exp_path)\n",
                "        copy_source('BMGAN', exp_path)\n",
                "        shutil.copytree('score_sde/models', os.path.join(exp_path, 'score_sde/models'))\n",
                "\n",
                "\n",
                "coeff = Diffusion_Coefficients(args, device)\n",
                "\n",
                "# Only the Brownian-bridge posterior is supported here; any other value aborts the cell.\n",
                "if args.posterior == \"brownian_bridge\":\n",
                "    pos_coeff = BrownianPosterior_Coefficients(args, device)\n",
                "else:\n",
                "    raise ValueError('ONLY Brownian Bridge posterior')\n",
                "\n",
                "# NOTE: this rebinds `T`. The config cell interpolates `T` into '--num_timesteps {T}', so re-running\n",
                "# that cell after this one would format this object instead of the original step count.\n",
                "T = get_time_schedule(args, device)\n",
                "\n",
                "# Optionally restore training state from <exp_path>/content.pth.\n",
                "# NOTE(review): optimizerG, schedulerG, netD, optimizerD and schedulerD are not created in this cell;\n",
                "# unless an earlier cell defines them, this branch raises NameError - confirm before using --resume.\n",
                "if args.resume:\n",
                "    checkpoint_file = os.path.join(exp_path, 'content.pth')\n",
                "    checkpoint = torch.load(checkpoint_file, map_location=device)\n",
                "    init_epoch = checkpoint['epoch']\n",
                "    epoch = init_epoch\n",
                "    netG.load_state_dict(checkpoint['netG_dict'])\n",
                "    # load G\n",
                "\n",
                "    optimizerG.load_state_dict(checkpoint['optimizerG'])\n",
                "    schedulerG.load_state_dict(checkpoint['schedulerG'])\n",
                "    # load D\n",
                "    netD.load_state_dict(checkpoint['netD_dict'])\n",
                "    optimizerD.load_state_dict(checkpoint['optimizerD'])\n",
                "    schedulerD.load_state_dict(checkpoint['schedulerD'])\n",
                "    global_step = checkpoint['global_step']\n",
                "    print(\"=> loaded checkpoint (epoch {})\"\n",
                "              .format(checkpoint['epoch']))\n",
                "else:\n",
                "    # Fresh run: start all counters at zero.\n",
                "    global_step, epoch, init_epoch = 0, 0, 0\n",
                "\n"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "19069825",
            "metadata": {},
            "source": [
                "# Markovian Projection"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "d47a188f",
            "metadata": {},
            "outputs": [],
            "source": [
                "import math\n",
                "\n",
                "class BasePrior:\n",
                "    def __init__(self):\n",
                "        pass\n",
                "    \n",
                "    @torch.no_grad()\n",
                "    def sample(self, x: torch.Tensor) -> torch.Tensor:\n",
                "        raise NotImplementedError()\n",
                "\n",
                "\n",
                "class WienerPrior(BasePrior):\n",
                "    \"\"\"Prior that perturbs its input with Gaussian noise of variance `eps`.\"\"\"\n",
                "\n",
                "    def __init__(self, eps: float = 1):\n",
                "        self.eps = eps\n",
                "\n",
                "    @torch.no_grad()\n",
                "    def sample(self, x: torch.Tensor) -> torch.Tensor:\n",
                "        \"\"\"Return `x` plus standard-normal noise scaled by sqrt(eps).\"\"\"\n",
                "        noise_scale = math.sqrt(self.eps)\n",
                "        noise = torch.randn_like(x)\n",
                "        return x + noise_scale * noise\n",
                "    \n",
                "class CondLoaderSampler():\n",
                "    def __init__(self, loader, plan='ind', reverse=False):\n",
                "        self.loader = loader\n",
                "        self.reverse = reverse\n",
                "        self.plan = plan\n",
                "        self.it = iter(self.loader)\n",
                "        self.brown_prior = WienerPrior(args.epsilon)\n",
                "\n",
                "        \n",
                "    def sample(self, size=5):\n",
                "        assert size <= self.loader.batch_size\n",
                "        try:\n",
                "            if self.plan == 'ind' or self.plan == 'inv_color':\n",
                "                batch_x, batch_y = next(self.it)\n",
                "            elif self.plan == 'id':\n",
                "                batch_x, _ = next(self.it)\n",
                "                batch_y = batch_x\n",
                "            elif self.plan == 'aid':\n",
                "                batch_x, _ = next(self.it)\n",
                "                batch_y = -batch_x\n",
                "            elif self.plan == 'ind_id':\n",
                "                batch_x, _ = next(self.it)\n",
                "                batch_y, _ = next(self.it)\n",
                "            elif self.plan == 'ipf':\n",
                "                if self.reverse:\n",
                "                    _, batch_y = next(self.it)\n",
                "                    batch_x = self.brown_prior.sample(batch_y)\n",
                "                    batch_x = (batch_x - batch_x.min(dim=0)) / (batch_x.max(dim=0)- batch_x.min(dim=0))\n",
                "                else:\n",
                "                    batch_x, _ = next(self.it)\n",
                "                    batch_y = self.brown_prior.sample(batch_x)\n",
                "        except StopIteration:\n",
                "            self.it = iter(self.loader)\n",
                "            return self.sample(size)\n",
                "        except RuntimeError:\n",
                "            self.it = iter(self.loader)\n",
                "            return self.sample(size)\n",
                "        if len(batch_x) < size:\n",
                "            return self.sample(size)\n",
                "            \n",
                "        if self.reverse:\n",
                "            return batch_x[:size], batch_y[:size]\n",
                "        \n",
                "        return batch_y[:size], batch_x[:size]\n",
                "    \n",
                "class OTSampler():\n",
                "    def __init__(self, loader):\n",
                "        self.loader = loader\n",
                "        self.ot_plan_sampler = OTPlanSampler('exact')\n",
                "\n",
                "    def sample(self, size=5):\n",
                "        x, y = self.loader.sample(size=size)\n",
                "        return self.ot_plan_sampler.sample_plan(x, y)\n",
                "        \n",
                "\n",
                "class XSampler():\n",
                "    \"\"\"Adapter exposing only the first element of the pairs produced by a wrapped sampler.\"\"\"\n",
                "\n",
                "    def __init__(self, sampler: CondLoaderSampler):\n",
                "        self.sampler = sampler\n",
                "\n",
                "    def sample(self, size=5):\n",
                "        \"\"\"Draw from the wrapped sampler and return element 0 of the result.\"\"\"\n",
                "        drawn = self.sampler.sample(size)\n",
                "        return drawn[0]\n",
                "    \n",
                "        \n",
                "class ModelCondSampler:\n",
                "    \"\"\"Pairs inputs from a sampler with model outputs computed under EMA-averaged parameters.\"\"\"\n",
                "\n",
                "    def __init__(self, sampler: XSampler, model_sample_fn, ema_g):\n",
                "        self.sampler = sampler\n",
                "        self.model_sample_fn = model_sample_fn\n",
                "        self.ema_g = ema_g\n",
                "\n",
                "    def sample(self, size=5):\n",
                "        \"\"\"Return (inputs, model_sample_fn(inputs)); the model call runs inside ema_g.average_parameters().\"\"\"\n",
                "        inputs = self.sampler.sample(size)\n",
                "\n",
                "        with self.ema_g.average_parameters():\n",
                "            outputs = self.model_sample_fn(inputs)\n",
                "\n",
                "        return inputs, outputs\n",
                "\n",
                "def bmgan_sample_fn(y):\n",
                "    \"\"\"Run sample_from_model on `y` with the notebook globals and return the first element of its result.\"\"\"\n",
                "    outputs = sample_from_model(pos_coeff, netG_proj, args.num_timesteps, y, T, args, return_trajectory=True)\n",
                "    return outputs[0]\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "3d1f806e",
            "metadata": {},
            "outputs": [],
            "source": [
                "\n",
                "def log_images(x, x_t_1, fake_samples, x_0_pred_list, sample_traj, postfix='fw'):\n",
                "    \"\"\"Save diagnostic image grids under <exp_path>/<save_dir>/pics and mirror them to wandb.\n",
                "\n",
                "    Uses the notebook globals `exp_path`, `save_dir` and `wandb`. The direction is from `x_t_1` to `x`.\n",
                "\n",
                "    Args:\n",
                "        x: image batch shown as the target side.\n",
                "        x_t_1: image batch shown as the input side.\n",
                "        fake_samples: generated image batch.\n",
                "        x_0_pred_list: list of image batches, concatenated along dim 3 between `x_t_1` and `x`.\n",
                "        sample_traj: sequence of image batches, concatenated along dim 3.\n",
                "        postfix: suffix used in the file names and wandb keys.\n",
                "    \"\"\"\n",
                "    pics_dir = os.path.join(exp_path, save_dir, 'pics')\n",
                "\n",
                "    def save_grid(images, file_name, nrow):\n",
                "        # Write one normalized grid into the pics directory and hand back its path.\n",
                "        target = os.path.join(pics_dir, file_name)\n",
                "        torchvision.utils.save_image(images, target, normalize=True, nrow=nrow)\n",
                "        return target\n",
                "\n",
                "    # Target / input / generated stacked along dim 2, first 20 items.\n",
                "    stacked = torch.cat([x[:20], x_t_1[:20], fake_samples[:20]], dim=2)\n",
                "    sample_img_path = save_grid(stacked, f'sample_discrete_M_{postfix}.png', 20)\n",
                "\n",
                "    traj_strip = torch.cat(sample_traj, dim=3)[:12]\n",
                "    sample_traj_path = save_grid(traj_strip, f'sample_traj_discrete_M_{postfix}.png', 1)\n",
                "\n",
                "    x_0_pred_strip = torch.cat([x_t_1] + x_0_pred_list + [x], dim=3)[:12]\n",
                "    sample_x_0_pred_path = save_grid(x_0_pred_strip, f'sample_x_0_pred_discrete_M_{postfix}.png', 1)\n",
                "\n",
                "    sample_input_path = save_grid(x_t_1[:16:4], f'input_{postfix}.png', 1)\n",
                "    samples_path = save_grid(fake_samples[:16], f'samples_{postfix}.png', 4)\n",
                "\n",
                "    if wandb.run:\n",
                "        # One wandb.log call per image, in the same order as the files were written.\n",
                "        to_log = [\n",
                "            (f'Sample_{postfix}', sample_img_path),\n",
                "            (f'Sample_traj_{postfix}', sample_traj_path),\n",
                "            (f'Sample_x0_pred_{postfix}', sample_x_0_pred_path),\n",
                "            (f'Input {postfix}', sample_input_path),\n",
                "            (f'Samples {postfix}', samples_path),\n",
                "        ]\n",
                "        for key, image_path in to_log:\n",
                "            wandb.log({key: wandb.Image(image_path)})\n",
                "\n",
                "    \n",
                "def calculate_transport_cost(x_samples, y_samples):\n",
                "    return F.mse_loss(x_samples, y_samples)\n",
                "\n",
                "# Create <exp_path>/<save_dir>/pics. makedirs also creates the intermediate <save_dir> directory,\n",
                "# and exist_ok=True keeps the cell re-runnable instead of raising FileExistsError.\n",
                "os.makedirs(os.path.join(exp_path, save_dir, 'pics'), exist_ok=True)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "647c74f0",
            "metadata": {},
            "outputs": [],
            "source": [
                "    \n",
                "def markovian_projection(max_iter, condsampler, netG_proj, netD_proj,\n",
                "                         opt_G_proj, opt_D_proj,\n",
                "                         sch_G_proj, sch_D_proj, ema_g, D_opt_steps=5, fw_or_bw='fw'):\n",
                "    \n",
                "    for iteration in range(max_iter):\n",
                "        if prior:\n",
                "            condsampler_gt = condsampler\n",
                "        else:\n",
                "            if fw_or_bw == 'fw':\n",
                "                condsampler_gt = CondLoaderSampler(data_loader)\n",
                "            else:\n",
                "                condsampler_gt = CondLoaderSampler(data_loader, reverse=True)\n",
                "        \n",
                "        if rank == 0 and iteration % 1000 == 0:\n",
                "            \n",
                "            x, y = condsampler_gt.sample(args.batch_size)\n",
                "            \n",
                "            x, x_t_1 = x.to(device), y.to(device)\n",
                "            \n",
                "            \n",
                "            fake_sample, x_0_pred_list, trajectory = sample_from_model(pos_coeff, netG_proj, args.num_timesteps, x_t_1, T, args, return_trajectory=True)\n",
                "            # log to wandb\n",
                "            \n",
                "            log_images(x, x_t_1, fake_sample, x_0_pred_list, trajectory, postfix=fw_or_bw)\n",
                "            \n",
                "            for i in range(args.num_timesteps+1):\n",
                "                torchvision.utils.save_image(trajectory[i], os.path.join(exp_path, f\"gen_x_{args.num_timesteps-i}_M_{fw_or_bw}_{iteration}.png\"), normalize=True)\n",
                "\n",
                "            trajectory = q_sample_supervised_trajectory(pos_coeff, x, x_t_1)\n",
                "            for i in range(args.num_timesteps+1):\n",
                "                torchvision.utils.save_image(trajectory[i], os.path.join(exp_path, f\"x_{args.num_timesteps-i}_M_{fw_or_bw}_{iteration}.png.png\"), normalize=True)\n",
                "            \n",
                "            repeated_samples = torch.cat([pic.unsqueeze(0).repeat([4, 1, 1, 1]) for pic in x_t_1[:10]], dim=0)\n",
                "            \n",
                "            fake_sample, x_0_pred_list, trajectory = sample_from_model(pos_coeff, netG_proj,\n",
                "                                                                       args.num_timesteps, repeated_samples,\n",
                "                                                                       T, args, return_trajectory=True)\n",
                "            # log to wandb\n",
                "            \n",
                "            log_images(x[:repeated_samples.shape[0]], repeated_samples, fake_sample, x_0_pred_list, trajectory, postfix=fw_or_bw + '_repeated')\n",
                "        \n",
                "        \n",
                "        if rank == 0 and iteration % 1000 == 0:\n",
                "\n",
                "            with ema_g.average_parameters():\n",
                "\n",
                "                x, y = condsampler_gt.sample(args.batch_size)\n",
                "\n",
                "                x, x_t_1 = x.to(device), y.to(device)\n",
                "                \n",
                "                fake_sample, x_0_pred_list, trajectory = sample_from_model(pos_coeff, netG_proj, args.num_timesteps, x_t_1, T, args, return_trajectory=True)\n",
                "                # log to wandb\n",
                "                \n",
                "                log_images(x, x_t_1, fake_sample, x_0_pred_list, trajectory, postfix=fw_or_bw + '_ema')\n",
                "                \n",
                "                for i in range(args.num_timesteps+1):\n",
                "                    torchvision.utils.save_image(trajectory[i], os.path.join(exp_path, f\"gen_x_{args.num_timesteps-i}_M_{fw_or_bw}_{iteration}_ema.png\"), normalize=True)\n",
                "\n",
                "                trajectory = q_sample_supervised_trajectory(pos_coeff, x, x_t_1)\n",
                "                for i in range(args.num_timesteps+1):\n",
                "                    torchvision.utils.save_image(trajectory[i], os.path.join(exp_path, f\"x_{args.num_timesteps-i}_M_{fw_or_bw}_{iteration}_ema.png\"), normalize=True)\n",
                "                \n",
                "                repeated_samples = torch.cat([pic.unsqueeze(0).repeat([4, 1, 1, 1]) for pic in x_t_1[:10]], dim=0)\n",
                "                \n",
                "                fake_sample, x_0_pred_list, trajectory = sample_from_model(pos_coeff, netG_proj,\n",
                "                                                                        args.num_timesteps, repeated_samples,\n",
                "                                                                        T, args, return_trajectory=True)\n",
                "                # log to wandb\n",
                "                \n",
                "                log_images(x[:repeated_samples.shape[0]], repeated_samples, fake_sample, x_0_pred_list, trajectory, postfix=fw_or_bw + '_repeated' + '_ema')\n",
                "                \n",
                "        x, y = condsampler.sample(args.batch_size)\n",
                "        x, y = x.to(device), y.to(device)\n",
                "        \n",
                "        #-----Discriminator Opt Step-----\n",
                "        \n",
                "        # Get D ready for optimization\n",
                "        for p in netD_proj.parameters():  \n",
                "            p.requires_grad = True  \n",
                "        \n",
                "        netD_proj.zero_grad()\n",
                "\n",
                "        #sample from p(x_0)\n",
                "        real_data = x.to(device, non_blocking=True)\n",
                "        input_real_data = y.to(device, non_blocking=True)\n",
                "\n",
                "        if args.posterior == \"ddpm\":\n",
                "            t = torch.randint(0, args.num_timesteps, (1,), device=device).repeat(real_data.size(0))\n",
                "            x_t, x_tp1 = q_sample_supervised_pairs(pos_coeff, real_data, t, input_real_data)\n",
                "        elif args.posterior == \"brownian_bridge\":\n",
                "            t = torch.randint(0, args.num_timesteps, (real_data.size(0),), device=device)\n",
                "            x_t, x_tp1 = q_sample_supervised_pairs_brownian(pos_coeff, real_data, t, input_real_data)\n",
                "\n",
                "        x_t.requires_grad = True\n",
                "\n",
                "        # train with real\n",
                "        D_real = netD_proj(x_t, t, x_tp1.detach()).view(-1)\n",
                "\n",
                "        errD_real = F.softplus(-D_real)\n",
                "        errD_real = errD_real.mean()\n",
                "\n",
                "        errD_real.backward(retain_graph=True)\n",
                "\n",
                "        if args.lazy_reg is None:\n",
                "            grad_real = torch.autograd.grad(\n",
                "                        outputs=D_real.sum(), inputs=x_t, create_graph=True\n",
                "                        )[0]\n",
                "            grad_penalty = (\n",
                "                            grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2\n",
                "                            ).mean()\n",
                "\n",
                "\n",
                "            grad_penalty = args.r1_gamma / 2 * grad_penalty\n",
                "            grad_penalty.backward()\n",
                "        else:\n",
                "            if iteration % args.lazy_reg == 0:\n",
                "                grad_real = torch.autograd.grad(\n",
                "                        outputs=D_real.sum(), inputs=x_t, create_graph=True\n",
                "                        )[0]\n",
                "                grad_penalty = (\n",
                "                            grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2\n",
                "                            ).mean()\n",
                "\n",
                "\n",
                "                grad_penalty = args.r1_gamma / 2 * grad_penalty\n",
                "                grad_penalty.backward()\n",
                "        \n",
                "        # train with fake\n",
                "        latent_z = torch.randn(batch_size, nz, device=device)\n",
                "\n",
                "        x_0_predict = netG_proj(x_tp1.detach(), t, latent_z)\n",
                "        x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)\n",
                "        \n",
                "        output = netD_proj(x_pos_sample, t, x_tp1.detach()).view(-1)\n",
                "        \n",
                "        errD_fake = F.softplus(output)\n",
                "        errD_fake = errD_fake.mean()\n",
                "        errD_fake.backward()\n",
                "        \n",
                "        errD = errD_real + errD_fake\n",
                "        # Update D\n",
                "        opt_D_proj.step()\n",
                "        \n",
                "        \n",
                "        #-----Generator Opt Step-----\n",
                "        \n",
                "        if iteration % D_opt_steps == 0:\n",
                "        \n",
                "            # Get G ready for optimization\n",
                "            for p in netD_proj.parameters():\n",
                "                p.requires_grad = False\n",
                "            netG_proj.zero_grad()\n",
                "\n",
                "            if args.posterior == \"brownian_bridge\":\n",
                "                t = torch.randint(0, args.num_timesteps, (real_data.size(0),), device=device)\n",
                "                x_t, x_tp1 = q_sample_supervised_pairs_brownian(pos_coeff, real_data, t, input_real_data)\n",
                "            else:\n",
                "                raise ValueError('ONLY Brownian Bridge posterior')\n",
                "\n",
                "\n",
                "            latent_z = torch.randn(batch_size, nz,device=device)\n",
                "\n",
                "\n",
                "            x_0_predict = netG_proj(x_tp1.detach(), t, latent_z)\n",
                "            x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)\n",
                "\n",
                "            output = netD_proj(x_pos_sample, t, x_tp1.detach()).view(-1)\n",
                "\n",
                "\n",
                "            errG = F.softplus(-output)\n",
                "            errG = errG.mean()\n",
                "\n",
                "\n",
                "            errG.backward()\n",
                "            opt_G_proj.step()\n",
                "            ema_g.update()\n",
                "        \n",
                "            if wandb.run:\n",
                "                wandb.log({'LossG': errG.item(), 'LossD': errD.item()})\n",
                "\n",
                "        if iteration % 100 == 0:\n",
                "            if rank == 0:\n",
                "                print('Markovain proj {}: Iter {}, G Loss: {}, D Loss: {}'.format(fw_or_bw, iteration, errG.item(), errD.item()))\n",
                "\n",
                "#         # SCH iteration step\n",
                "#         if not args.no_lr_decay and iteration % 1000 == 0:\n",
                "\n",
                "#             sch_G_proj.step()\n",
                "#             sch_D_proj.step()\n",
                "\n",
                "    if rank == 0:\n",
                "        \n",
                "        if args.save_content:\n",
                "            if epoch % args.save_content_every == 0:\n",
                "                print('Saving content.')\n",
                "                content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,\n",
                "                           'netG_dict': netG_proj.state_dict(), 'optimizerG': opt_G_proj.state_dict(),\n",
                "                           'netD_dict': netD_proj.state_dict(),\n",
                "                           'optimizerD': opt_D_proj.state_dict()}\n",
                "                \n",
                "                torch.save(content, os.path.join(exp_path, save_dir, f'content_{fw_or_bw}.pth'))\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "245c1a68",
            "metadata": {},
            "outputs": [],
            "source": [
                "# Open the Weights & Biases run; the cells below only log metrics when `wandb.run` is set.\n",
                "wandb.init(\n",
                "    project=\"BM_GAN\",\n",
                "    name=exp_name,\n",
                "    config=config,\n",
                ")\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "f86c7d4d-3da5-457e-be3f-0c1362b0482c",
            "metadata": {},
            "outputs": [],
            "source": [
                "from torch.utils.data import DataLoader\n",
                "\n",
                "# Evaluation data: colored-MNIST digits 2 and 3, loaded with train=False and one shared coloring seed.\n",
                "coloring_seed = 42\n",
                "\n",
                "colored_mnist_2, colored_mnist_3 = (\n",
                "    load_paired_colored_mnist_one_side(target_number=digit, train=False,\n",
                "                                       seed=coloring_seed, dataset_size=10000)\n",
                "    for digit in (2, 3)\n",
                ")\n",
                "\n",
                "colored_mnist_2_loader, colored_mnist_3_loader = (\n",
                "    DataLoader(dataset, batch_size=256)\n",
                "    for dataset in (colored_mnist_2, colored_mnist_3)\n",
                ")\n",
                "\n",
                "# Loader pairs handed to compute_fid_and_ot_cost as (true, model input):\n",
                "# the backward model is fed digit 2 and compared with digit 3, the forward model the opposite.\n",
                "true_dataloader_bw, model_input_dataloader_bw = colored_mnist_3_loader, colored_mnist_2_loader\n",
                "true_dataloader_fw, model_input_dataloader_fw = colored_mnist_2_loader, colored_mnist_3_loader\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "8c28aa51",
            "metadata": {
                "scrolled": true
            },
            "outputs": [],
            "source": [
                "# Forward direction: generator, discriminator, their Adam optimizers and the generator EMA.\n",
                "netG_fw = NCSNpp(args).to(device)\n",
                "\n",
                "# The discriminator takes twice the image channels (presumably the two images it scores are concatenated - confirm).\n",
                "netD_fw = Discriminator_small(\n",
                "    nc=2 * args.num_channels,\n",
                "    ngf=args.ngf,\n",
                "    t_emb_dim=args.t_emb_dim,\n",
                "    act=nn.LeakyReLU(0.2),\n",
                ").to(device)\n",
                "\n",
                "optimizerG_fw = optim.Adam(\n",
                "    netG_fw.parameters(), lr=args.lr_g, betas=(args.beta1, args.beta2)\n",
                ")\n",
                "optimizerD_fw = optim.Adam(\n",
                "    netD_fw.parameters(), lr=args.lr_d, betas=(args.beta1, args.beta2)\n",
                ")\n",
                "\n",
                "# No LR schedulers are created here (a CosineAnnealingLR variant was tried earlier);\n",
                "# markovian_projection still receives the two scheduler arguments, so pass None.\n",
                "schedulerG_fw = schedulerD_fw = None\n",
                "\n",
                "ema_g_fw = ExponentialMovingAverage(netG_fw.parameters(), decay=args.ema_decay)\n",
                "ema_g_fw.to(device)\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "322dfe4e",
            "metadata": {},
            "outputs": [],
            "source": [
                "# Backward direction: same architecture and hyper-parameters as the forward networks.\n",
                "netG_bw = NCSNpp(args).to(device)\n",
                "\n",
                "# The discriminator takes twice the image channels (presumably the two images it scores are concatenated - confirm).\n",
                "netD_bw = Discriminator_small(\n",
                "    nc=2 * args.num_channels,\n",
                "    ngf=args.ngf,\n",
                "    t_emb_dim=args.t_emb_dim,\n",
                "    act=nn.LeakyReLU(0.2),\n",
                ").to(device)\n",
                "\n",
                "optimizerG_bw = optim.Adam(\n",
                "    netG_bw.parameters(), lr=args.lr_g, betas=(args.beta1, args.beta2)\n",
                ")\n",
                "optimizerD_bw = optim.Adam(\n",
                "    netD_bw.parameters(), lr=args.lr_d, betas=(args.beta1, args.beta2)\n",
                ")\n",
                "\n",
                "# No LR schedulers are created here (a CosineAnnealingLR variant was tried earlier);\n",
                "# markovian_projection still receives the two scheduler arguments, so pass None.\n",
                "schedulerG_bw = schedulerD_bw = None\n",
                "\n",
                "ema_g_bw = ExponentialMovingAverage(netG_bw.parameters(), decay=args.ema_decay)\n",
                "ema_g_bw.to(device)\n"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "dd307cb0",
            "metadata": {},
            "source": [
                "## IPMF (Iterative Proportional Markovian Fitting)"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "40f77380",
            "metadata": {},
            "source": [
                "### First forward iteration"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "d194fa7c",
            "metadata": {},
            "outputs": [],
            "source": [
                "# First forward projection: training pairs come from `prior_data_loader` with `args.plan`.\n",
                "prior = True\n",
                "condsampler_fw = CondLoaderSampler(prior_data_loader, args.plan)\n",
                "# (wrapping the sampler in OTSampler for mini-batch OT is currently disabled)\n",
                "\n",
                "markovian_projection(\n",
                "    markovian_proj_iters, condsampler_fw, netG_fw, netD_fw,\n",
                "    optimizerG_fw, optimizerD_fw, schedulerG_fw, schedulerD_fw, ema_g_fw,\n",
                "    D_opt_steps=D_opt_steps, fw_or_bw='fw',\n",
                ")\n",
                "\n",
                "# Qualitative check on one fresh batch, sampled with the raw (non-EMA) generator weights.\n",
                "x, y = condsampler_fw.sample(args.batch_size)\n",
                "x, x_t_1 = x.to(device), y.to(device)\n",
                "\n",
                "fake_sample, x_0_pred_list, trajectory = sample_from_model(\n",
                "    pos_coeff, netG_fw, args.num_timesteps, x_t_1, T, args, return_trajectory=True\n",
                ")\n",
                "log_images(x, x_t_1, fake_sample, x_0_pred_list, trajectory, postfix='fw_after_M')\n",
                "\n",
                "# Transport cost of the model output vs. that of the sampled pairing, both measured from the model input.\n",
                "transport_cost_model = calculate_transport_cost(x_t_1, fake_sample)\n",
                "transport_cost_ind = calculate_transport_cost(x_t_1, x)\n",
                "\n",
                "\n",
                "def sample_fn_fw(y):\n",
                "    \"\"\"Return the final sample that `netG_fw` produces when started from batch `y`.\"\"\"\n",
                "    return sample_from_model(pos_coeff, netG_fw, args.num_timesteps, y, T,\n",
                "                             args, return_trajectory=True)[0]\n",
                "\n",
                "\n",
                "# FID / OT cost with the raw weights, then with the EMA weights swapped in.\n",
                "# `sample_fn_fw` reads `netG_fw` when it is called, so the same callable serves both runs.\n",
                "fid, ot_cost = compute_fid_and_ot_cost(true_dataloader_fw, model_input_dataloader_fw, sample_fn_fw)\n",
                "\n",
                "with ema_g_fw.average_parameters():\n",
                "    fid_ema, ot_cost_ema = compute_fid_and_ot_cost(true_dataloader_fw, model_input_dataloader_fw, sample_fn_fw)\n",
                "\n",
                "if wandb.run:\n",
                "    wandb.log({\n",
                "        'T_cost_fw': transport_cost_model, 'T_cost_ind': transport_cost_ind,\n",
                "        'fid_fw': fid, 'fid_fw_ema': fid_ema,\n",
                "        'ot_cost_fw': ot_cost, 'ot_cost_fw_ema': ot_cost_ema,\n",
                "    })\n",
                "\n",
                "# Checkpoint number 0: raw weights first, then the EMA weights.\n",
                "checkpoint_dir = os.path.join(exp_path, save_dir)\n",
                "torch.save(netG_fw.state_dict(), os.path.join(checkpoint_dir, 'netG_fw_0.pth'))\n",
                "\n",
                "with ema_g_fw.average_parameters():\n",
                "    torch.save(netG_fw.state_dict(), os.path.join(checkpoint_dir, 'netG_fw_0_ema.pth'))\n"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "276ca070",
            "metadata": {},
            "source": [
                "### First backward iteration"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "cf237ac0",
            "metadata": {},
            "outputs": [],
            "source": [
                "# First backward projection: same prior loader and plan as the forward one, sampled with reverse=True.\n",
                "prior = True\n",
                "condsampler_bw = CondLoaderSampler(prior_data_loader, args.plan, reverse=True)\n",
                "# (wrapping the sampler in OTSampler for mini-batch OT is currently disabled)\n",
                "\n",
                "markovian_projection(\n",
                "    markovian_proj_iters, condsampler_bw, netG_bw, netD_bw,\n",
                "    optimizerG_bw, optimizerD_bw, schedulerG_bw, schedulerD_bw, ema_g_bw,\n",
                "    D_opt_steps=D_opt_steps, fw_or_bw='bw',\n",
                ")\n",
                "\n",
                "# Qualitative check on one fresh batch, sampled with the raw (non-EMA) generator weights.\n",
                "x, y = condsampler_bw.sample(args.batch_size)\n",
                "x, x_t_1 = x.to(device), y.to(device)\n",
                "\n",
                "fake_sample, x_0_pred_list, trajectory = sample_from_model(\n",
                "    pos_coeff, netG_bw, args.num_timesteps, x_t_1, T, args, return_trajectory=True\n",
                ")\n",
                "log_images(x, x_t_1, fake_sample, x_0_pred_list, trajectory, postfix='bw_after_M')\n",
                "\n",
                "# Transport cost of the model output vs. that of the sampled pairing, both measured from the model input.\n",
                "transport_cost_model = calculate_transport_cost(x_t_1, fake_sample)\n",
                "transport_cost_ind = calculate_transport_cost(x_t_1, x)\n",
                "\n",
                "\n",
                "def sample_fn_bw(y):\n",
                "    \"\"\"Return the final sample that `netG_bw` produces when started from batch `y`.\"\"\"\n",
                "    return sample_from_model(pos_coeff, netG_bw, args.num_timesteps, y, T,\n",
                "                             args, return_trajectory=True)[0]\n",
                "\n",
                "\n",
                "# FID / OT cost with the raw weights, then with the EMA weights swapped in.\n",
                "# `sample_fn_bw` reads `netG_bw` when it is called, so the same callable serves both runs.\n",
                "fid, ot_cost = compute_fid_and_ot_cost(true_dataloader_bw, model_input_dataloader_bw, sample_fn_bw)\n",
                "\n",
                "with ema_g_bw.average_parameters():\n",
                "    fid_ema, ot_cost_ema = compute_fid_and_ot_cost(true_dataloader_bw, model_input_dataloader_bw, sample_fn_bw)\n",
                "\n",
                "if wandb.run:\n",
                "    wandb.log({\n",
                "        'T_cost_bw': transport_cost_model, 'T_cost_ind': transport_cost_ind,\n",
                "        'fid_bw': fid, 'fid_bw_ema': fid_ema,\n",
                "        'ot_cost_bw': ot_cost, 'ot_cost_bw_ema': ot_cost_ema,\n",
                "    })\n",
                "\n",
                "# Checkpoint number 0: raw weights first, then the EMA weights.\n",
                "checkpoint_dir = os.path.join(exp_path, save_dir)\n",
                "torch.save(netG_bw.state_dict(), os.path.join(checkpoint_dir, 'netG_bw_0.pth'))\n",
                "\n",
                "with ema_g_bw.average_parameters():\n",
                "    torch.save(netG_bw.state_dict(), os.path.join(checkpoint_dir, 'netG_bw_0_ema.pth'))\n"
            ]
        },
        {
            "cell_type": "markdown",
            "id": "d78a64d4",
            "metadata": {},
            "source": [
                "## Subsequent IPMF iterations"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "id": "353fd012",
            "metadata": {},
            "outputs": [],
            "source": [
                "prior = False\n",
                "\n",
                "\n",
                "def _restart_from_ema(netG, ema_g):\n",
                "    \"\"\"Create a fresh generator initialised with the EMA weights of `netG`.\n",
                "\n",
                "    Args:\n",
                "        netG: generator whose EMA weights are copied.\n",
                "        ema_g: EMA tracker attached to `netG`.\n",
                "\n",
                "    Returns:\n",
                "        (new generator, new EMA tracker for it, new Adam optimizer for it).\n",
                "    \"\"\"\n",
                "    netG_clone = NCSNpp(args).to(device)\n",
                "    # Only the weight copy has to happen while the EMA weights are swapped into `netG`.\n",
                "    with ema_g.average_parameters():\n",
                "        netG_clone.load_state_dict(copy.deepcopy(netG.state_dict()))\n",
                "    new_ema = ExponentialMovingAverage(netG_clone.parameters(), decay=args.ema_decay)\n",
                "    new_optimizer = optim.Adam(netG_clone.parameters(),\n",
                "                               lr=args.lr_g, betas=(args.beta1, args.beta2))\n",
                "    return netG_clone, new_ema, new_optimizer\n",
                "\n",
                "\n",
                "def _make_model_sampler(netG, move_to_device):\n",
                "    \"\"\"Return a callable mapping a batch to the final sample produced by `netG`.\n",
                "\n",
                "    `netG` is bound when the callable is created, so rebinding the notebook-level\n",
                "    generator name later cannot silently change which model is sampled.\n",
                "\n",
                "    Args:\n",
                "        netG: generator to sample from.\n",
                "        move_to_device: if True, the batch is moved to `device` before sampling.\n",
                "    \"\"\"\n",
                "    def sample(batch):\n",
                "        if move_to_device:\n",
                "            batch = batch.to(device)\n",
                "        return sample_from_model(pos_coeff, netG, args.num_timesteps, batch,\n",
                "                                 T, args, return_trajectory=True)[0]\n",
                "    return sample\n",
                "\n",
                "\n",
                "def _evaluate_and_checkpoint(direction, ipmf_iter, netG, ema_g, condsampler,\n",
                "                             true_dataloader, model_input_dataloader):\n",
                "    \"\"\"Log images and metrics for one freshly trained generator, then save it.\n",
                "\n",
                "    Args:\n",
                "        direction: 'fw' or 'bw'; used in image postfixes, metric keys and file names.\n",
                "        ipmf_iter: index of the current outer iteration, used in checkpoint file names.\n",
                "        netG: generator that was just trained.\n",
                "        ema_g: EMA tracker of `netG`.\n",
                "        condsampler: sampler that provided the training pairs of `netG`.\n",
                "        true_dataloader: reference loader passed to compute_fid_and_ot_cost.\n",
                "        model_input_dataloader: input loader passed to compute_fid_and_ot_cost.\n",
                "    \"\"\"\n",
                "    # Qualitative check on one fresh batch, sampled with the raw (non-EMA) weights.\n",
                "    x, y = condsampler.sample(args.batch_size)\n",
                "    x, x_t_1 = x.to(device), y.to(device)\n",
                "\n",
                "    fake_sample, x_0_pred_list, trajectory = sample_from_model(\n",
                "        pos_coeff, netG, args.num_timesteps, x_t_1, T, args, return_trajectory=True\n",
                "    )\n",
                "    log_images(x, x_t_1, fake_sample, x_0_pred_list, trajectory, postfix=f'{direction}_after_M')\n",
                "\n",
                "    transport_cost_model = calculate_transport_cost(x_t_1, fake_sample)\n",
                "    transport_cost_ind = calculate_transport_cost(x_t_1, x)\n",
                "\n",
                "    # FID / OT cost with the raw weights, then with the EMA weights swapped in.\n",
                "    # The callable reads the parameters of `netG` at call time, so one instance serves both runs.\n",
                "    sample_fn = _make_model_sampler(netG, move_to_device=False)\n",
                "    fid, ot_cost = compute_fid_and_ot_cost(true_dataloader, model_input_dataloader, sample_fn)\n",
                "    with ema_g.average_parameters():\n",
                "        fid_ema, ot_cost_ema = compute_fid_and_ot_cost(true_dataloader, model_input_dataloader, sample_fn)\n",
                "\n",
                "    if wandb.run:\n",
                "        wandb.log({\n",
                "            f'T_cost_{direction}': transport_cost_model, 'T_cost_ind': transport_cost_ind,\n",
                "            f'fid_{direction}': fid, f'fid_{direction}_ema': fid_ema,\n",
                "            f'ot_cost_{direction}': ot_cost, f'ot_cost_{direction}_ema': ot_cost_ema,\n",
                "        })\n",
                "\n",
                "    # Raw weights first, then the EMA weights.\n",
                "    checkpoint_dir = os.path.join(exp_path, save_dir)\n",
                "    torch.save(netG.state_dict(), os.path.join(checkpoint_dir, f'netG_{direction}_{ipmf_iter}.pth'))\n",
                "    with ema_g.average_parameters():\n",
                "        torch.save(netG.state_dict(), os.path.join(checkpoint_dir, f'netG_{direction}_{ipmf_iter}_ema.pth'))\n",
                "\n",
                "\n",
                "for i in range(1, ipmf_iters + 1):\n",
                "\n",
                "    # ----Forward model learning (y -> x)----\n",
                "\n",
                "    if ema_start_ipmf:\n",
                "        # Restart the generator from its own EMA weights; the discriminator and its optimizer are kept.\n",
                "        netG_fw, ema_g_fw, optimizerG_fw = _restart_from_ema(netG_fw, ema_g_fw)\n",
                "\n",
                "    # Training pairs come from pushing data through the current backward generator.\n",
                "    # `ema_g_bw` is handed to ModelCondSampler (presumably so it samples with EMA weights - confirm).\n",
                "    bmgan_sample_bw = _make_model_sampler(netG_bw, move_to_device=True)\n",
                "    sampler_x = XSampler(CondLoaderSampler(data_loader))\n",
                "    condsampler_fw = ModelCondSampler(sampler_x, bmgan_sample_bw, ema_g_bw)\n",
                "\n",
                "    markovian_projection(\n",
                "        inner_ipmf_mark_proj_iters, condsampler_fw, netG_fw, netD_fw,\n",
                "        optimizerG_fw, optimizerD_fw, schedulerG_fw, schedulerD_fw, ema_g_fw,\n",
                "        D_opt_steps=D_opt_steps, fw_or_bw='fw',\n",
                "    )\n",
                "\n",
                "    _evaluate_and_checkpoint('fw', i, netG_fw, ema_g_fw, condsampler_fw,\n",
                "                             true_dataloader_fw, model_input_dataloader_fw)\n",
                "\n",
                "    # ----Backward model learning (x -> y)----\n",
                "\n",
                "    if ema_start_ipmf:\n",
                "        # Restart the generator from its own EMA weights; the discriminator and its optimizer are kept.\n",
                "        netG_bw, ema_g_bw, optimizerG_bw = _restart_from_ema(netG_bw, ema_g_bw)\n",
                "\n",
                "    # Training pairs come from pushing data through the forward generator trained just above.\n",
                "    bmgan_sample_fw = _make_model_sampler(netG_fw, move_to_device=True)\n",
                "    sampler_x = XSampler(CondLoaderSampler(data_loader, reverse=True))\n",
                "    condsampler_bw = ModelCondSampler(sampler_x, bmgan_sample_fw, ema_g_fw)\n",
                "\n",
                "    markovian_projection(\n",
                "        inner_ipmf_mark_proj_iters, condsampler_bw, netG_bw, netD_bw,\n",
                "        optimizerG_bw, optimizerD_bw, schedulerG_bw, schedulerD_bw, ema_g_bw,\n",
                "        D_opt_steps=D_opt_steps, fw_or_bw='bw',\n",
                "    )\n",
                "\n",
                "    _evaluate_and_checkpoint('bw', i, netG_bw, ema_g_bw, condsampler_bw,\n",
                "                             true_dataloader_bw, model_input_dataloader_bw)\n"
            ]
        }
    ],
    "metadata": {
        "celltoolbar": "Tags",
        "kernelspec": {
            "display_name": "Python 3 (ipykernel)",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.10.14"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 5
}
