{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import Predictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x, gen_noise\n",
    "from tqdm.notebook import trange\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import adjs_to_graphs, init_flags, quantize_mol, \\\n",
    "                              compute_group_assignments, count_across_community_edges, est_p_intra_inter, \\\n",
    "                              is_sbm_graph\n",
    "from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx\n",
    "from moses.metrics.metrics import get_all_metrics\n",
    "from evaluation.stats import eval_graph_list\n",
    "\n",
    "from losses_guidance import compute_dp1, compute_dp2, compute_nodedp1, compute_nodedp2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the compute device through the project helper, then pin to a single\n",
    "# device instead of unconditionally overwriting the helper's answer with GPU 0.\n",
    "device = load_device()\n",
    "if isinstance(device, list):\n",
    "    # Keep the list form (checked again when building device_id) but use only its first entry.\n",
    "    device = device[:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_file = 'sample_qm9' # sample_sbm sample_community_small\n",
    "seed = 0\n",
    "config = get_config(config_file, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load checkpoint --------\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_seed(configt.seed)\n",
    "train_graph_list, _ = load_data(configt, get_graph_list=True)\n",
    "with open(f'data/{configt.data.data.lower()}_test_nx.pkl', 'rb') as f:\n",
    "    test_graph_list = pickle.load(f)                                   # for NSPDK MMD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_smiles, test_smiles = load_smiles(configt.data.data)\n",
    "train_smiles, test_smiles = canonicalize_smiles(train_smiles), canonicalize_smiles(test_smiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "logger = Logger(str(os.path.join(log_dir, f'{log_name}.log')), mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview up to 25 distinct test graphs, one per panel.\n",
    "f, axs = plt.subplots(5, 5, figsize=(10, 10))\n",
    "for i, ax in enumerate(axs.ravel()):\n",
    "    if i < len(test_graph_list):\n",
    "        nx.draw(test_graph_list[i], with_labels=False, ax=ax, node_size=10)\n",
    "    else:\n",
    "        ax.axis('off')  # fewer than 25 graphs available: blank the unused panel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not check_log(log_folder_name, log_name):\n",
    "    logger.log(f'{log_name}')\n",
    "    start_log(logger, configt)\n",
    "    train_log(logger, configt)\n",
    "sample_log(logger, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load models --------\n",
    "model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if config.sample.use_ema:\n",
    "    ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "    ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "    \n",
    "    ema_x.copy_to(model_x.parameters())\n",
    "    ema_adj.copy_to(model_adj.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'GEN SEED: {config.sample.seed}')\n",
    "load_seed(config.sample.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sde_x = load_sde(configt.sde.x)\n",
    "sde_adj = load_sde(configt.sde.adj)\n",
    "max_node_num  = configt.data.max_node_num\n",
    "\n",
    "device_id = f'cuda:{device[0]}' if isinstance(device, list) else device\n",
    "print(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "continuous = True\n",
    "\n",
    "predictor = config.sampler.predictor\n",
    "corrector = config.sampler.corrector\n",
    "\n",
    "snr=config.sampler.snr\n",
    "scale_eps=config.sampler.scale_eps\n",
    "n_steps=config.sampler.n_steps\n",
    "\n",
    "probability_flow = config.sample.probability_flow\n",
    "eps = config.sample.eps\n",
    "denoise = config.sample.noise_removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if configt.data.data in ['QM9', 'ZINC250k']: # CHANGED\n",
    "    batch_size = 10000\n",
    "    # batch_size = configt.data.batch_size\n",
    "else:\n",
    "    batch_size = configt.data.batch_size\n",
    "shape_x = (batch_size, max_node_num, configt.data.max_feat_num)\n",
    "shape_adj = (batch_size, max_node_num, max_node_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_flags_iter = init_flags(train_graph_list, configt, batch_size=batch_size).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_x = sde_x.prior_sampling(shape_x).to(device_id)\n",
    "init_adj = sde_adj.prior_sampling_sym(shape_adj).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReverseDiffusionPredictor(Predictor):\n",
    "    \"\"\"Reverse-diffusion predictor that can steer one variable with guidance.\n",
    "\n",
    "    `obj` names the variable this predictor updates ('x' or 'adj'). When\n",
    "    `guidance_args` is provided, steps whose discrete timestep is > 0 go\n",
    "    through `guidance`; all other steps are plain reverse-diffusion updates.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, obj, sde, score_fn, probability_flow=False, guidance_args=None):\n",
    "        super().__init__(sde, score_fn, probability_flow)\n",
    "        self.obj = obj\n",
    "\n",
    "        self.guidance_args = guidance_args\n",
    "\n",
    "        # Group-assignment tensor, created lazily on the first guided 'adj' step.\n",
    "        self.Z = None\n",
    "\n",
    "    def _init_balanced_random_Z(self, adj):\n",
    "        \"\"\"Randomly split the nodes of each graph into two near-equal groups.\n",
    "\n",
    "        Returns a float one-hot tensor on `adj.device`, permuted so the group\n",
    "        axis comes second: (adj.shape[0], n_com, adj.shape[1]).\n",
    "        \"\"\"\n",
    "        n_graphs, n_nodes = adj.shape[0], adj.shape[1]\n",
    "        n_elems = n_nodes // 2\n",
    "        # One row of labels (0s then 1s), copied once per graph ...\n",
    "        template_row = np.array([0] * n_elems + [1] * (n_nodes - n_elems))\n",
    "        idxs_com = np.tile(template_row, (n_graphs, 1))\n",
    "        # ... and shuffled independently for every graph.\n",
    "        for i in range(n_graphs):\n",
    "            np.random.shuffle(idxs_com[i])\n",
    "        n_com = self.guidance_args.get('n_com', 2)\n",
    "        Zs = torch.nn.functional.one_hot(torch.tensor(idxs_com), num_classes=n_com).float().to(adj.device)\n",
    "        return Zs.permute(0, 2, 1)\n",
    "\n",
    "    def guidance(self, x, adj, flags, t, is_adj):\n",
    "        \"\"\"Take one guided reverse step for the selected variable.\n",
    "\n",
    "        Dispatches on `guidance_args.method`:\n",
    "          * 'greedy': draw `n_traj` candidate next states and keep, per sample,\n",
    "            the one whose estimated clean object has the lowest loss.\n",
    "          * 'zero': zeroth-order step built from loss differences along random\n",
    "            noise directions around the noiseless mean.\n",
    "          * 'loss': back-propagate the loss through the score function and add\n",
    "            a gradient step to the usual noisy update.\n",
    "\n",
    "        Returns `(obj, obj_mean)`: the updated variable and the mean of the\n",
    "        unguided reverse step.\n",
    "\n",
    "        Raises NotImplementedError for any other method.\n",
    "        \"\"\"\n",
    "        obj = adj if is_adj else x\n",
    "\n",
    "        dt = -1. / self.rsde.N\n",
    "        timestep = (t * (self.rsde.N - 1) / self.rsde.T).long()\n",
    "\n",
    "        # NOTE(review): eval() resolves the loss from a config string; only use\n",
    "        # trusted guidance configs.\n",
    "        loss_fn = eval(self.guidance_args.loss_fn)\n",
    "        # Copy so the per-call 'Z' entry is not written back into the config.\n",
    "        loss_kwargs = dict(self.guidance_args.get('loss_kwargs', {}))\n",
    "\n",
    "        if self.obj == 'adj' and self.Z is None:\n",
    "            self.Z = self._init_balanced_random_Z(adj)\n",
    "        # Z only exists for 'adj' guidance; skip it instead of failing on None.\n",
    "        if self.Z is not None:\n",
    "            loss_kwargs['Z'] = self.Z.clone()\n",
    "\n",
    "        if self.guidance_args.method == 'greedy':\n",
    "            n_traj = self.guidance_args.n_traj\n",
    "\n",
    "            f, G = self.rsde.discretize(x, adj, flags, t, is_adj=is_adj)\n",
    "\n",
    "            obj_mean = obj - f\n",
    "\n",
    "            # Same device as obj, matching the buffer used by the 'zero' branch.\n",
    "            losses = torch.zeros(n_traj, obj.shape[0], device=obj.device)\n",
    "            obj_hats = []\n",
    "            for i in range(n_traj):\n",
    "                z = gen_noise(obj, flags, sym=is_adj)\n",
    "                obj_hat = obj_mean.clone() + G[:, None, None] * z\n",
    "                obj_hats.append(obj_hat)\n",
    "\n",
    "                score_i = self.score_fn(x, obj_hat, flags, t+dt) if is_adj else self.score_fn(obj_hat, adj, flags, t+dt)\n",
    "\n",
    "                obj0hat = self.sde.obj0estimation(obj_hat, score_i, timestep)\n",
    "                obj0hat_masked = mask_adjs(obj0hat, flags) if is_adj else mask_x(obj0hat, flags)\n",
    "\n",
    "                losses[i,:] = loss_fn(obj0hat_masked, **loss_kwargs)\n",
    "\n",
    "            # Per-sample index of the lowest-loss candidate, expanded to obj's shape for gather.\n",
    "            losses_expanded = torch.argmin(losses, dim=0).view(1, obj.shape[0], 1, 1).expand(1, obj.shape[0], obj.shape[1], obj.shape[2]).to(obj.device)\n",
    "\n",
    "            return torch.gather(torch.stack(obj_hats, dim=0), 0, losses_expanded).squeeze(0), obj_mean\n",
    "\n",
    "        elif self.guidance_args.method == 'zero':\n",
    "            n_traj = self.guidance_args.n_traj\n",
    "\n",
    "            f, G = self.rsde.discretize(x, adj, flags, t, is_adj=is_adj)\n",
    "\n",
    "            obj_mean = obj - f\n",
    "\n",
    "            # Baseline loss at the noiseless mean. Both branches score obj_mean, the\n",
    "            # same point that is handed to obj0estimation below.\n",
    "            score_no_noise = self.score_fn(x, obj_mean.clone(), flags, t) if is_adj else self.score_fn(obj_mean.clone(), adj, flags, t)\n",
    "            obj0hat_no_noise = self.sde.obj0estimation(obj_mean.clone(), score_no_noise, timestep)\n",
    "            obj0hat_masked_no_noise = mask_adjs(obj0hat_no_noise, flags) if is_adj else mask_x(obj0hat_no_noise, flags)\n",
    "            no_noise_loss = loss_fn(obj0hat_masked_no_noise, **loss_kwargs)\n",
    "\n",
    "            losses = torch.zeros(n_traj, obj.shape[0], device=obj.device)\n",
    "            noise_directions = []\n",
    "            for i in range(n_traj):\n",
    "                z = gen_noise(obj, flags, sym=is_adj)\n",
    "                obj_hat = obj_mean.clone() + G[:, None, None] * z\n",
    "                noise_directions.append(z)\n",
    "\n",
    "                score_i = self.score_fn(x, obj_hat, flags, t+dt) if is_adj else self.score_fn(obj_hat, adj, flags, t+dt)\n",
    "\n",
    "                obj0hat = self.sde.obj0estimation(obj_hat, score_i, timestep)\n",
    "                obj0hat_masked = mask_adjs(obj0hat, flags) if is_adj else mask_x(obj0hat, flags)\n",
    "\n",
    "                losses[i,:] = loss_fn(obj0hat_masked, **loss_kwargs)\n",
    "\n",
    "            # Finite-difference weights: loss change per unit of noise scale G.\n",
    "            weights = (losses - no_noise_loss[None,:]) / G[None, :]\n",
    "            directions = torch.stack(noise_directions, dim=0)\n",
    "            obj = obj_mean - self.guidance_args.lr_zero * (weights[:,:,None,None] * directions).mean(dim=0)\n",
    "            return obj, obj_mean\n",
    "\n",
    "        elif self.guidance_args.method == 'loss':\n",
    "\n",
    "            # Gradients are re-enabled locally to differentiate the loss w.r.t. obj.\n",
    "            with torch.enable_grad():\n",
    "\n",
    "                obj.requires_grad = True\n",
    "\n",
    "                score = self.score_fn(x, obj, flags, t) if is_adj else self.score_fn(obj, adj, flags, t)\n",
    "\n",
    "                obj0hat = self.sde.obj0estimation(obj, score, timestep)\n",
    "                obj0hat_masked = mask_adjs(obj0hat, flags) if is_adj else mask_x(obj0hat, flags)\n",
    "\n",
    "                loss = loss_fn(obj0hat_masked, **loss_kwargs).mean()\n",
    "\n",
    "                loss.backward()\n",
    "\n",
    "                obj_grad = obj.grad.detach().clone()\n",
    "                obj.grad = None\n",
    "\n",
    "            f, G = self.rsde.discretize(x, adj, flags, t, is_adj=is_adj)\n",
    "\n",
    "            obj_mean = obj - f\n",
    "\n",
    "            z = gen_noise(obj, flags, sym=is_adj)\n",
    "            obj = obj_mean + G[:, None, None] * z\n",
    "\n",
    "            if self.guidance_args.lr_guidance_method == 'adaptive':\n",
    "                # Step size scaled by the inverse of the current loss magnitude.\n",
    "                obj -= self.guidance_args.lr_guidance / torch.abs(loss) * obj_grad\n",
    "            else:\n",
    "                obj -= self.guidance_args.lr_guidance * obj_grad\n",
    "\n",
    "            return obj, obj_mean\n",
    "        else:\n",
    "            raise NotImplementedError(f\"guidance method {self.guidance_args.method} not yet supported.\")\n",
    "\n",
    "\n",
    "    def update_fn(self, x, adj, flags, t):\n",
    "        \"\"\"Advance the selected variable by one reverse step.\n",
    "\n",
    "        Uses `guidance` when guidance arguments are set and the discrete\n",
    "        timestep is > 0; otherwise performs the unguided update.\n",
    "        Returns `(var, var_mean)`.\n",
    "        \"\"\"\n",
    "        timestep = (t[0] * (self.rsde.N - 1) / self.rsde.T).long()\n",
    "\n",
    "        var = x if self.obj == 'x' else adj\n",
    "\n",
    "        if self.guidance_args is not None and \\\n",
    "                timestep > 0:\n",
    "            var, var_mean = self.guidance(x, adj, flags, t, is_adj=self.obj == 'adj')\n",
    "        else:\n",
    "            f, G = self.rsde.discretize(x, adj, flags, t, is_adj=self.obj == 'adj')\n",
    "            z = gen_noise(var, flags, sym=self.obj == 'adj')\n",
    "            var_mean = var - f\n",
    "            var = var_mean + G[:, None, None] * z\n",
    "        return var, var_mean\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeated here for cleaner code\n",
    "class EulerMaruyamaPredictor(Predictor):\n",
    "    \"\"\"Euler-Maruyama predictor that can steer one variable with guidance.\n",
    "\n",
    "    `obj` names the variable this predictor updates ('x' or 'adj'). Guidance is\n",
    "    applied on steps with discrete timestep > 0 and, when `guidance_args`\n",
    "    contains 'from_t', only while the timestep is below `from_t`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, obj, sde, score_fn, probability_flow=False, guidance_args=None):\n",
    "        super().__init__(sde, score_fn, probability_flow)\n",
    "        self.obj = obj\n",
    "\n",
    "        self.guidance_args = guidance_args\n",
    "\n",
    "        # Group-assignment tensor, created lazily on the first guided 'adj' step.\n",
    "        self.Z = None\n",
    "\n",
    "    def _init_Z(self, x, adj, flags, t, timestep):\n",
    "        \"\"\"Build the group-assignment tensor selected by `guidance_args.method_Z`.\n",
    "\n",
    "        'communities' derives groups from the graphs implied by the current\n",
    "        clean-adjacency estimate; 'random' draws one group per node uniformly.\n",
    "        Returns a tensor on `adj.device` with the group axis second.\n",
    "\n",
    "        Raises ValueError for any other `method_Z`.\n",
    "        \"\"\"\n",
    "        n_com = self.guidance_args.get('n_com', 2)\n",
    "        method_Z = self.guidance_args.method_Z\n",
    "        if method_Z == 'communities':\n",
    "            # `quantize` is not among the notebook's top-level imports (only\n",
    "            # quantize_mol is); it is assumed to live in utils.graph_utils -- TODO confirm.\n",
    "            from utils.graph_utils import quantize\n",
    "            adj0 = self.sde.obj0estimation(adj, self.score_fn(x, adj, flags, t), timestep)\n",
    "            adj0 = mask_adjs(adj0, flags)\n",
    "            gs = adjs_to_graphs(quantize(adj0), True)\n",
    "            Zs = torch.zeros((adj.shape[0], adj.shape[1], n_com)).to(adj.device)\n",
    "            for g, graph in enumerate(gs):\n",
    "                Zs[g,:graph.number_of_nodes(),:] = compute_group_assignments(graph, n_com)\n",
    "        elif method_Z == 'random':\n",
    "            idxs_com = np.random.randint(0, n_com, size=(adj.shape[0], adj.shape[1]))\n",
    "            Zs = torch.nn.functional.one_hot(torch.tensor(idxs_com), num_classes=n_com).float().to(adj.device)\n",
    "        else:\n",
    "            # Fail with a clear message instead of an AttributeError on None later.\n",
    "            raise ValueError(f\"unknown method_Z {method_Z!r}; expected 'communities' or 'random'.\")\n",
    "        return Zs.permute(0, 2, 1)\n",
    "\n",
    "    def guidance(self, x, adj, flags, t, is_adj):\n",
    "        \"\"\"Take one guided Euler-Maruyama step for the selected variable.\n",
    "\n",
    "        Dispatches on `guidance_args.method`:\n",
    "          * 'greedy': draw `n_traj` candidate next states and keep, per sample,\n",
    "            the one whose estimated clean object has the lowest loss.\n",
    "          * 'zero': take the noisy step, then apply a zeroth-order gradient\n",
    "            estimate from `n_traj` perturbations of size `delta`, optionally\n",
    "            clipped ('clip') or normalized ('norm').\n",
    "          * 'loss': back-propagate the loss through the score function and add\n",
    "            a gradient step to the usual noisy update.\n",
    "\n",
    "        Returns `(obj, obj_mean)`: the updated variable and the drift-only mean.\n",
    "\n",
    "        Raises NotImplementedError for any other method.\n",
    "        \"\"\"\n",
    "        obj = adj if is_adj else x\n",
    "\n",
    "        dt = -1. / self.rsde.N\n",
    "        timestep = (t * (self.rsde.N - 1) / self.rsde.T).long()\n",
    "\n",
    "        # NOTE(review): eval() resolves the loss from a config string; only use\n",
    "        # trusted guidance configs.\n",
    "        loss_fn = eval(self.guidance_args.loss_fn)\n",
    "        # Copy so the per-call 'Z' entry is not written back into the config.\n",
    "        loss_kwargs = dict(self.guidance_args.get('loss_kwargs', {}))\n",
    "\n",
    "        if self.obj == 'adj' and self.Z is None:\n",
    "            self.Z = self._init_Z(x, adj, flags, t, timestep)\n",
    "        # Z only exists for 'adj' guidance; skip it instead of failing on None.\n",
    "        if self.Z is not None:\n",
    "            loss_kwargs['Z'] = self.Z.clone()\n",
    "\n",
    "        if self.guidance_args.method == 'greedy':\n",
    "            n_traj = self.guidance_args.n_traj\n",
    "\n",
    "            drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=is_adj)\n",
    "\n",
    "            obj_mean = obj + drift * dt\n",
    "\n",
    "            # Same device as obj, matching the buffer used by the 'zero' branch.\n",
    "            losses = torch.zeros(n_traj, obj.shape[0], device=obj.device)\n",
    "            obj_hats = []\n",
    "            for i in range(n_traj):\n",
    "                z = gen_noise(obj, flags, sym=is_adj)\n",
    "                obj_hat = obj_mean.clone() + diffusion[:, None, None] * np.sqrt(-dt) * z\n",
    "                obj_hats.append(obj_hat)\n",
    "\n",
    "                score_i = self.score_fn(x, obj_hat, flags, t+dt) if is_adj else self.score_fn(obj_hat, adj, flags, t+dt)\n",
    "\n",
    "                obj0hat = self.sde.obj0estimation(obj_hat, score_i, timestep)\n",
    "                obj0hat_masked = mask_adjs(obj0hat, flags) if is_adj else mask_x(obj0hat, flags)\n",
    "\n",
    "                losses[i,:] = loss_fn(obj0hat_masked, **loss_kwargs)\n",
    "\n",
    "            # Per-sample index of the lowest-loss candidate, expanded to obj's shape for gather.\n",
    "            losses_expanded = torch.argmin(losses, dim=0).view(1, obj.shape[0], 1, 1).expand(1, obj.shape[0], obj.shape[1], obj.shape[2]).to(obj.device)\n",
    "\n",
    "            return torch.gather(torch.stack(obj_hats, dim=0), 0, losses_expanded).squeeze(0), obj_mean\n",
    "\n",
    "        elif self.guidance_args.method == 'zero':\n",
    "            n_traj = self.guidance_args.n_traj\n",
    "\n",
    "            drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=is_adj)\n",
    "\n",
    "            obj_mean = obj + drift * dt\n",
    "\n",
    "            # Noisy Euler-Maruyama step taken from obj_mean so the reverse drift is\n",
    "            # applied, as in the 'greedy' and 'loss' branches.\n",
    "            z = gen_noise(obj, flags, sym=is_adj)\n",
    "            obj = obj_mean + diffusion[:, None, None] * np.sqrt(-dt) * z\n",
    "\n",
    "            # Baseline loss at the stepped point, evaluated at t+dt like the perturbed points below.\n",
    "            score_no_noise = self.score_fn(x, obj.clone(), flags, t+dt) if is_adj else self.score_fn(obj.clone(), adj, flags, t+dt)\n",
    "            obj0hat_no_noise = self.sde.obj0estimation(obj.clone(), score_no_noise, timestep)\n",
    "            obj0hat_masked_no_noise = mask_adjs(obj0hat_no_noise, flags) if is_adj else mask_x(obj0hat_no_noise, flags)\n",
    "            no_noise_loss = loss_fn(obj0hat_masked_no_noise, **loss_kwargs)\n",
    "\n",
    "            losses = torch.zeros(n_traj, obj.shape[0], device=obj.device)\n",
    "            noise_directions = []\n",
    "            for i in range(n_traj):\n",
    "                z = gen_noise(obj, flags, sym=is_adj)\n",
    "                obj_hat = obj.clone() + self.guidance_args.delta * z\n",
    "                noise_directions.append(z)\n",
    "\n",
    "                score_i = self.score_fn(x, obj_hat, flags, t+dt) if is_adj else self.score_fn(obj_hat, adj, flags, t+dt)\n",
    "\n",
    "                obj0hat = self.sde.obj0estimation(obj_hat, score_i, timestep)\n",
    "                obj0hat_masked = mask_adjs(obj0hat, flags) if is_adj else mask_x(obj0hat, flags)\n",
    "\n",
    "                losses[i,:] = loss_fn(obj0hat_masked, **loss_kwargs)\n",
    "\n",
    "            # Finite-difference weights: loss change per unit of perturbation size delta.\n",
    "            weights = (losses - no_noise_loss[None,:]) / self.guidance_args.delta # (diffusion[None, :] * np.sqrt(-dt))\n",
    "            directions = torch.stack(noise_directions, dim=0)\n",
    "            weighted_directions = (weights[:,:,None,None] * directions).mean(dim=0)\n",
    "\n",
    "            if self.guidance_args.clip_method == 'clip':\n",
    "                # Gradient clipping\n",
    "                weighted_directions = torch.clamp(weighted_directions, -self.guidance_args.clip, self.guidance_args.clip)\n",
    "\n",
    "                obj = obj - self.guidance_args.lr_zero * weighted_directions\n",
    "            elif self.guidance_args.clip_method == 'norm':\n",
    "                norm_grad_step = torch.linalg.norm(weighted_directions, dim=(1,2)) / (weighted_directions.shape[1] * weighted_directions.shape[2])\n",
    "                # Avoid dividing by a (near-)zero norm.\n",
    "                norm_grad_step = torch.where(norm_grad_step < 1e-7, torch.ones_like(norm_grad_step), norm_grad_step)\n",
    "\n",
    "                obj = obj - self.guidance_args.lr_zero * weighted_directions / norm_grad_step[:,None,None]\n",
    "            else:\n",
    "                obj = obj - self.guidance_args.lr_zero * weighted_directions\n",
    "            return obj, obj_mean\n",
    "\n",
    "        elif self.guidance_args.method == 'loss':\n",
    "\n",
    "            # Gradients are re-enabled locally to differentiate the loss w.r.t. obj.\n",
    "            with torch.enable_grad():\n",
    "\n",
    "                obj.requires_grad = True\n",
    "\n",
    "                score = self.score_fn(x, obj, flags, t) if is_adj else self.score_fn(obj, adj, flags, t)\n",
    "\n",
    "                obj0hat = self.sde.obj0estimation(obj, score, timestep)\n",
    "                obj0hat_masked = mask_adjs(obj0hat, flags) if is_adj else mask_x(obj0hat, flags)\n",
    "\n",
    "                loss = loss_fn(obj0hat_masked, **loss_kwargs).mean()\n",
    "\n",
    "                loss.backward()\n",
    "\n",
    "                obj_grad = obj.grad.detach().clone()\n",
    "                obj.grad = None\n",
    "\n",
    "            # Reverse drift from the score computed above; use the guided variable\n",
    "            # (obj), not always adj, as the forward-SDE input.\n",
    "            drift, diffusion = self.sde.sde(obj, t)\n",
    "            score = score.detach()\n",
    "            drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.rsde.probability_flow else 1.)\n",
    "            # -------- Set the diffusion function to zero for ODEs. --------\n",
    "            # Kept as a tensor so the indexing below still works.\n",
    "            if self.rsde.probability_flow:\n",
    "                diffusion = torch.zeros_like(diffusion)\n",
    "\n",
    "            obj_mean = obj + drift * dt\n",
    "\n",
    "            z = gen_noise(obj, flags, sym=is_adj)\n",
    "            obj = obj_mean + diffusion[:, None, None] * np.sqrt(-dt) * z\n",
    "\n",
    "            if self.guidance_args.lr_guidance_method == 'adaptive':\n",
    "                # Step size scaled by the inverse of the current loss magnitude.\n",
    "                obj -= self.guidance_args.lr_guidance / torch.abs(loss) * obj_grad\n",
    "            else:\n",
    "                obj -= self.guidance_args.lr_guidance * obj_grad\n",
    "\n",
    "            return obj, obj_mean\n",
    "        else:\n",
    "            raise NotImplementedError(f\"guidance method {self.guidance_args.method} not yet supported.\")\n",
    "\n",
    "\n",
    "    def update_fn(self, x, adj, flags, t):\n",
    "        \"\"\"Advance the selected variable by one reverse step.\n",
    "\n",
    "        Uses `guidance` when guidance arguments are set, the discrete timestep\n",
    "        is > 0 and (if 'from_t' is configured) the timestep is below `from_t`;\n",
    "        otherwise performs the unguided Euler-Maruyama update.\n",
    "        Returns `(var, var_mean)`.\n",
    "        \"\"\"\n",
    "        dt = -1. / self.rsde.N\n",
    "        timestep = (t[0] * (self.rsde.N - 1) / self.rsde.T).long()\n",
    "\n",
    "        var = x if self.obj == 'x' else adj\n",
    "\n",
    "        cond_guidance = self.guidance_args is not None and timestep > 0\n",
    "        if self.guidance_args is not None and 'from_t' in self.guidance_args:\n",
    "            cond_guidance = cond_guidance and timestep.item() < self.guidance_args.from_t\n",
    "\n",
    "        if cond_guidance:\n",
    "            var, var_mean = self.guidance(x, adj, flags, t, is_adj=self.obj == 'adj')\n",
    "        else:\n",
    "            z = gen_noise(var, flags, sym=self.obj == 'adj')\n",
    "            drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=self.obj == 'adj')\n",
    "            var_mean = var + drift * dt\n",
    "            var = var_mean + diffusion[:, None, None] * np.sqrt(-dt) * z\n",
    "        return var, var_mean\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None):\n",
    "    \"\"\"Run the corrector/predictor reverse-diffusion loop and discretize the result.\n",
    "\n",
    "    Starts from clones of `init_x` / `init_adj` or, when they are None, from\n",
    "    the SDE priors. Reads the notebook globals `shape_x`, `shape_adj`,\n",
    "    `device_id` and `eps`.\n",
    "\n",
    "    Returns `(samples_int, x)`: `quantize_mol(adj)` and the node features\n",
    "    thresholded at 0.5 with one extra trailing channel appended.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        if init_x is not None:\n",
    "            x = init_x.clone()\n",
    "        else:\n",
    "            x = predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        if init_adj is not None:\n",
    "            adj = init_adj.clone()\n",
    "        else:\n",
    "            adj = predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "\n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "        diff_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, diff_steps, device=device_id)\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):\n",
    "            t = timesteps[i]\n",
    "            vec_t = torch.ones(shape_adj[0], device=t.device) * t\n",
    "\n",
    "            # Both updates of a stage see the x from before that stage (_x).\n",
    "            _x = x\n",
    "            x, x_mean = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = corrector_adj.update_fn(_x, adj, flags, vec_t)\n",
    "            if torch.any(torch.isnan(adj)):\n",
    "                # Stop early, but report it: the returned samples are unreliable.\n",
    "                warnings.warn(f'NaN in adj at sampling step {i}; stopping reverse diffusion early.')\n",
    "                break\n",
    "\n",
    "            _x = x\n",
    "            x, x_mean = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = predictor_adj.update_fn(_x, adj, flags, vec_t)\n",
    "    samples_int = quantize_mol(adj)\n",
    "\n",
    "    # adj = torch.nn.functional.one_hot(torch.tensor(samples_int), num_classes=4).permute(0, 3, 1, 2)\n",
    "    x = torch.where(x > 0.5, 1, 0)\n",
    "    x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)      # 32, 9, 4 -> 32, 9, 5\n",
    "    return samples_int, x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_group_assignments_tensor(adj, n_com):\n",
    "    \"\"\"Collect per-graph group assignments into one tensor, group axis second.\n",
    "\n",
    "    Rows beyond a graph's node count are left at zero.\n",
    "    \"\"\"\n",
    "    graphs = adjs_to_graphs(adj, True)\n",
    "    batch, n_nodes = adj.shape[0], adj.shape[1]\n",
    "    assignments = torch.zeros((batch, n_nodes, n_com)).to(adj.device)\n",
    "    for idx, graph in enumerate(graphs):\n",
    "        n_present = graph.number_of_nodes()\n",
    "        assignments[idx, :n_present, :] = compute_group_assignments(graph, n_com)\n",
    "    return assignments.permute(0, 2, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "method_Z = \"random\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/fairness/{method_Z}/greedy.yaml')\n",
    "guidance_args = edict({'method': 'greedy', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_greedy, x_greedy = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Zs_greedy = predictor_obj_adj.Z.clone().to(device_id)\n",
    "\n",
    "# as_tensor keeps this cell re-runnable once adj_greedy is already a tensor.\n",
    "adj_greedy = torch.as_tensor(adj_greedy, device=device_id)\n",
    "\n",
    "# Evaluate each metric once and reuse it for both mean and std.\n",
    "dp2_greedy_vals = compute_dp2(adj_greedy, Zs_greedy)\n",
    "nodedp2_greedy_vals = compute_nodedp2(adj_greedy, Zs_greedy)\n",
    "\n",
    "dp1_greedy = compute_dp1(adj_greedy, Zs_greedy).mean().item()\n",
    "dp2_greedy = dp2_greedy_vals.mean().item()\n",
    "dp2_greedy_std = dp2_greedy_vals.std().item()\n",
    "nodedp1_greedy = compute_nodedp1(adj_greedy, Zs_greedy).mean().item()\n",
    "nodedp2_greedy = nodedp2_greedy_vals.mean().item()\n",
    "nodedp2_greedy_std = nodedp2_greedy_vals.std().item()\n",
    "across_comm_edges_greedy = count_across_community_edges(adj_greedy, Zs_greedy).mean().item()\n",
    "\n",
    "dp1_greedy, dp2_greedy, nodedp1_greedy, nodedp2_greedy, across_comm_edges_greedy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/fairness/{method_Z}/loss.yaml')\n",
    "guidance_args = edict({'method': 'loss', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_loss, x_loss = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Zs_loss = predictor_obj_adj.Z.clone().to(device_id)\n",
    "\n",
    "# as_tensor keeps this cell re-runnable once adj_loss is already a tensor.\n",
    "adj_loss = torch.as_tensor(adj_loss, device=device_id)\n",
    "\n",
    "# Evaluate each metric once and reuse it for both mean and std.\n",
    "dp2_loss_vals = compute_dp2(adj_loss, Zs_loss)\n",
    "nodedp2_loss_vals = compute_nodedp2(adj_loss, Zs_loss)\n",
    "\n",
    "dp1_loss = compute_dp1(adj_loss, Zs_loss).mean().item()\n",
    "dp2_loss = dp2_loss_vals.mean().item()\n",
    "dp2_loss_std = dp2_loss_vals.std().item()\n",
    "nodedp1_loss = compute_nodedp1(adj_loss, Zs_loss).mean().item()\n",
    "nodedp2_loss = nodedp2_loss_vals.mean().item()\n",
    "nodedp2_loss_std = nodedp2_loss_vals.std().item()\n",
    "across_comm_edges_loss = count_across_community_edges(adj_loss, Zs_loss).mean().item()\n",
    "\n",
    "dp1_loss, dp2_loss, nodedp1_loss, nodedp2_loss, across_comm_edges_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero order optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/fairness/{method_Z}/zero.yaml')\n",
    "guidance_args = edict({'method': 'zero', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions for the node features and the adjacency, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Pick the predictor / corrector classes from the sampler settings.\n",
    "if predictor == 'Reverse':\n",
    "    predictor_fn = ReverseDiffusionPredictor\n",
    "else:\n",
    "    predictor_fn = EulerMaruyamaPredictor\n",
    "if corrector == 'Langevin':\n",
    "    corrector_fn = LangevinCorrector\n",
    "else:\n",
    "    corrector_fn = NoneCorrector\n",
    "\n",
    "# Node-feature solver pair (no guidance arguments are passed to it).\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "# Adjacency solver pair; only the adjacency predictor receives guidance_args.\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_zero, x_zero = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics for the zero-order sample, evaluated against the Z tensor stored on the\n",
    "# adjacency predictor (presumably the group assignment used while sampling).\n",
    "Zs_zero = predictor_obj_adj.Z.clone().to(device_id)\n",
    "\n",
    "adj_zero = torch.tensor(adj_zero, device=device_id)\n",
    "\n",
    "dp1_zero = compute_dp1(adj_zero, Zs_zero).mean().item()\n",
    "# Evaluate DP2 once; its mean and std are both read from the same result.\n",
    "dp2_zero_vals = compute_dp2(adj_zero, Zs_zero)\n",
    "dp2_zero = dp2_zero_vals.mean().item()\n",
    "dp2_zero_std = dp2_zero_vals.std().item()\n",
    "nodedp1_zero = compute_nodedp1(adj_zero, Zs_zero).mean().item()\n",
    "# Same for node DP2: one evaluation, two summary statistics.\n",
    "nodedp2_zero_vals = compute_nodedp2(adj_zero, Zs_zero)\n",
    "nodedp2_zero = nodedp2_zero_vals.mean().item()\n",
    "nodedp2_zero_std = nodedp2_zero_vals.std().item()\n",
    "across_comm_edges_zero = count_across_community_edges(adj_zero, Zs_zero).mean().item()\n",
    "\n",
    "dp1_zero, dp2_zero, nodedp1_zero, nodedp2_zero, across_comm_edges_zero"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DiGress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('DiGress')\n",
    "sys.path.append('DiGress/src')\n",
    "\n",
    "from guided_sampling import sample_fair_digress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the DiGress guidance settings for the dataset named by configt.data.data.\n",
    "guidance_config_digress = load_yaml_config(f'config_guidance/fairness/{method_Z}/digress.yaml')\n",
    "digress_dataset_cfg = guidance_config_digress[configt.data.data.lower()]\n",
    "\n",
    "digress_lambda = digress_dataset_cfg['guidance_lambda']\n",
    "digress_base_config_path = digress_dataset_cfg['base_config_path']\n",
    "digress_ckpt_path = digress_dataset_cfg['ckpt_path']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guided DiGress sampling with the guidance weight read from the YAML config.\n",
    "x_digress, adj_digress, Zs_digress = sample_fair_digress(\n",
    "    batch_size,\n",
    "    digress_base_config_path,\n",
    "    digress_ckpt_path,\n",
    "    device_id,\n",
    "    guidance_lambda=digress_lambda,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the sampled adjacencies to the evaluation device once instead of once per metric.\n",
    "adj_digress_dev = adj_digress.to(device_id)\n",
    "\n",
    "dp1_digress = compute_dp1(adj_digress_dev, Zs_digress).mean().item()\n",
    "# Evaluate DP2 once; its mean and std are both read from the same result.\n",
    "dp2_digress_vals = compute_dp2(adj_digress_dev, Zs_digress)\n",
    "dp2_digress = dp2_digress_vals.mean().item()\n",
    "dp2_digress_std = dp2_digress_vals.std().item()\n",
    "nodedp1_digress = compute_nodedp1(adj_digress_dev, Zs_digress).mean().item()\n",
    "# Same for node DP2: one evaluation, two summary statistics.\n",
    "nodedp2_digress_vals = compute_nodedp2(adj_digress_dev, Zs_digress)\n",
    "nodedp2_digress = nodedp2_digress_vals.mean().item()\n",
    "nodedp2_digress_std = nodedp2_digress_vals.std().item()\n",
    "across_comm_edges_digress = count_across_community_edges(adj_digress_dev, Zs_digress).mean().item()\n",
    "\n",
    "dp1_digress, dp2_digress, nodedp1_digress, nodedp2_digress, across_comm_edges_digress"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DiGress without guidance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unguided baseline: pass a zero guidance weight directly rather than overwriting\n",
    "# `digress_lambda`, so re-running the guided DiGress cell still uses the configured value.\n",
    "x_digress_unguided, adj_digress_unguided, Zs_digress_unguided = sample_fair_digress(\n",
    "    batch_size,\n",
    "    digress_base_config_path,\n",
    "    digress_ckpt_path,\n",
    "    device_id,\n",
    "    guidance_lambda=0.,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the sampled adjacencies to the evaluation device once instead of once per metric.\n",
    "adj_digress_unguided_dev = adj_digress_unguided.to(device_id)\n",
    "\n",
    "dp1_digress_unguided = compute_dp1(adj_digress_unguided_dev, Zs_digress_unguided).mean().item()\n",
    "# Evaluate DP2 once; its mean and std are both read from the same result.\n",
    "dp2_digress_unguided_vals = compute_dp2(adj_digress_unguided_dev, Zs_digress_unguided)\n",
    "dp2_digress_unguided = dp2_digress_unguided_vals.mean().item()\n",
    "dp2_digress_unguided_std = dp2_digress_unguided_vals.std().item()\n",
    "nodedp1_digress_unguided = compute_nodedp1(adj_digress_unguided_dev, Zs_digress_unguided).mean().item()\n",
    "# Same for node DP2: one evaluation, two summary statistics.\n",
    "nodedp2_digress_unguided_vals = compute_nodedp2(adj_digress_unguided_dev, Zs_digress_unguided)\n",
    "nodedp2_digress_unguided = nodedp2_digress_unguided_vals.mean().item()\n",
    "nodedp2_digress_unguided_std = nodedp2_digress_unguided_vals.std().item()\n",
    "across_comm_edges_digress_unguided = count_across_community_edges(adj_digress_unguided_dev, Zs_digress_unguided).mean().item()\n",
    "\n",
    "dp1_digress_unguided, dp2_digress_unguided, nodedp1_digress_unguided, nodedp2_digress_unguided, across_comm_edges_digress_unguided"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unconstrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions for the node features and the adjacency, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Pick the predictor / corrector classes from the sampler settings.\n",
    "if predictor == 'Reverse':\n",
    "    predictor_fn = ReverseDiffusionPredictor\n",
    "else:\n",
    "    predictor_fn = EulerMaruyamaPredictor\n",
    "if corrector == 'Langevin':\n",
    "    corrector_fn = LangevinCorrector\n",
    "else:\n",
    "    corrector_fn = NoneCorrector\n",
    "\n",
    "# Node-feature solver pair.\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "# Adjacency solver pair; no guidance_args are passed, so this run is unconstrained.\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_uncons, x_uncons = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assign half and half of the nodes randomly to each group\n",
    "n_elems = adj_uncons.shape[1] // 2\n",
    "# Create a template row with the correct number of elements\n",
    "template_row = np.array([0] * n_elems + [1] * (adj_uncons.shape[1] - n_elems))\n",
    "\n",
    "# Create an array where each row is a copy of the template row\n",
    "idxs_com = np.tile(template_row, (adj_uncons.shape[0], 1))\n",
    "\n",
    "# Apply a random permutation along the columns for each row\n",
    "for i in range(adj_uncons.shape[0]):\n",
    "    np.random.shuffle(idxs_com[i])\n",
    "# one_hot only accepts int64 indices; request torch.long explicitly because NumPy's\n",
    "# default integer type is platform dependent.\n",
    "# NOTE(review): guidance_args is whatever the most recently run guidance cell left\n",
    "# behind; only its 'n_com' entry (default 2) is read here.\n",
    "Zs = torch.nn.functional.one_hot(torch.tensor(idxs_com, dtype=torch.long), num_classes=guidance_args.get('n_com', 2)).float().to(device_id)\n",
    "# Swap the last two axes of the one-hot tensor before handing it to the metrics.\n",
    "Zs_uncons = Zs.permute(0,2,1)\n",
    "\n",
    "adj_uncons = torch.tensor(adj_uncons, device=device_id)\n",
    "\n",
    "dp1_uncons = compute_dp1(adj_uncons, Zs_uncons).mean().item()\n",
    "# Evaluate DP2 once; its mean and std are both read from the same result.\n",
    "dp2_uncons_vals = compute_dp2(adj_uncons, Zs_uncons)\n",
    "dp2_uncons = dp2_uncons_vals.mean().item()\n",
    "dp2_uncons_std = dp2_uncons_vals.std().item()\n",
    "nodedp1_uncons = compute_nodedp1(adj_uncons, Zs_uncons).mean().item()\n",
    "# Same for node DP2: one evaluation, two summary statistics.\n",
    "nodedp2_uncons_vals = compute_nodedp2(adj_uncons, Zs_uncons)\n",
    "nodedp2_uncons = nodedp2_uncons_vals.mean().item()\n",
    "nodedp2_uncons_std = nodedp2_uncons_vals.std().item()\n",
    "across_comm_edges_uncons = count_across_community_edges(adj_uncons, Zs_uncons).mean().item()\n",
    "\n",
    "dp1_uncons, dp2_uncons, nodedp1_uncons, nodedp2_uncons, across_comm_edges_uncons"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row labels, in the same order as the values listed for every metric below.\n",
    "method_names = ['GGDiff-G', 'GGDiff-C', 'GGDiff-Z', 'DiGress', 'Uncons. (DiGress)', 'Uncons. (GGDS)']\n",
    "metric_columns = {\n",
    "    'DP1': [dp1_loss, dp1_greedy, dp1_zero, dp1_digress, dp1_digress_unguided, dp1_uncons],\n",
    "    'DP2': [dp2_loss, dp2_greedy, dp2_zero, dp2_digress, dp2_digress_unguided, dp2_uncons],\n",
    "    'Node DP1': [nodedp1_loss, nodedp1_greedy, nodedp1_zero, nodedp1_digress, nodedp1_digress_unguided, nodedp1_uncons],\n",
    "    'Node DP2': [nodedp2_loss, nodedp2_greedy, nodedp2_zero, nodedp2_digress, nodedp2_digress_unguided, nodedp2_uncons],\n",
    "}\n",
    "\n",
    "# One row per method, indexed by the method name.\n",
    "df_results = pd.DataFrame({'Method': method_names, **metric_columns}).set_index('Method')\n",
    "\n",
    "# Display every metric column with four decimal places.\n",
    "df_results.style.format({column: \"{:.4f}\" for column in metric_columns})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print df_results as a LaTeX table; the smallest value of each reported metric is bold.\n",
    "min_vals = df_results.min(axis=0)\n",
    "# Metrics reported in the table, in column order (one header cell each).\n",
    "table_metrics = ['DP2', 'Node DP2']\n",
    "\n",
    "print(\"\\\\begin{table}[h]\")\n",
    "print(\"\\\\centering\")\n",
    "print(\"\\\\caption{Results for the fair graph generation experiment.}\")\n",
    "print(\"\\\\label{tab:fair_graph_gen}\")\n",
    "# One column for the method name plus one per reported metric, so the column\n",
    "# spec matches the header row.\n",
    "print(\"\\\\begin{tabular}{\" + \"c\" * (len(table_metrics) + 1) + \"}\")\n",
    "print(\"\\\\toprule\")\n",
    "print(\"\\\\textbf{Method} & \\\\textbf{$\\\\Delta$ DP} & \\\\textbf{$\\\\Delta \\\\text{DP}_{\\\\text{node}}$ } \\\\\\\\\")\n",
    "print(\"\\\\midrule\")\n",
    "for index, row in df_results.iterrows():\n",
    "    cells = [str(index)]\n",
    "    for metric in table_metrics:\n",
    "        val = row[metric]\n",
    "        formatted = f\"{val:.4f}\"\n",
    "        if val == min_vals[metric]:\n",
    "            formatted = f\"\\\\textbf{{{formatted}}}\"\n",
    "        cells.append(formatted)\n",
    "    # Join with ' & ' so no trailing ampersand opens an empty extra column.\n",
    "    print(\" & \".join(cells) + \" \\\\\\\\\")\n",
    "print(\"\\\\bottomrule\")\n",
    "print(\"\\\\end{tabular}\")\n",
    "print(\"\\\\end{table}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "digress-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
