{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import ReverseDiffusionPredictor, EulerMaruyamaPredictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x\n",
    "from tqdm.notebook import trange\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import init_flags, quantize_mol\n",
    "from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx\n",
    "from moses.metrics.metrics import get_all_metrics\n",
    "from evaluation.stats import eval_graph_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the DiGress directory and its src/ subdirectory importable (paths are relative to the notebook's working directory).\n",
    "import sys\n",
    "sys.path.append('DiGress')\n",
    "sys.path.append('DiGress/src')\n",
    "\n",
    "from guided_sampling import sample_digress"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = load_device()\n",
    "# Optional hard override of load_device(): a list of CUDA device ids (device_id is later built from\n",
    "# its first element). Set DEVICE_OVERRIDE = None to keep the device chosen by load_device().\n",
    "DEVICE_OVERRIDE = [1]\n",
    "if DEVICE_OVERRIDE is not None:\n",
    "    device = DEVICE_OVERRIDE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling-time config: provides config.ckpt, config.sampler.* and config.sample.* used below.\n",
    "config_file = 'sample_qm9'\n",
    "seed = 0\n",
    "config = get_config(config_file, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load checkpoint --------\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "# Config stored inside the checkpoint; data, SDE and training settings are read from it below.\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed with the checkpoint's seed before loading data.\n",
    "load_seed(configt.seed)\n",
    "# train_graph_list, _ = load_data(configt, get_graph_list=True)\n",
    "# Pickled networkx test graphs: reference set for NSPDK MMD and source of the ground-truth graphs used below.\n",
    "# NOTE: pickle.load can execute code stored in the file; only load trusted local data.\n",
    "with open(f'data/{configt.data.data.lower()}_test_nx.pkl', 'rb') as f:\n",
    "    test_graph_list = pickle.load(f)                                   # for NSPDK MMD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Canonical SMILES of both splits; test_smiles is the reference set passed to the MOSES metrics.\n",
    "train_smiles, test_smiles = load_smiles(configt.data.data)\n",
    "train_smiles, test_smiles = canonicalize_smiles(train_smiles), canonicalize_smiles(test_smiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Append-mode log file '<ckpt>-sample-guidance.log' inside the directory returned by set_log.\n",
    "log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "logger = Logger(str(os.path.join(log_dir, f'{log_name}.log')), mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The header (start/train info from configt) is written only when check_log(...) is falsy - presumably the\n",
    "# first run for this log name; the sampling config is logged on every run.\n",
    "if not check_log(log_folder_name, log_name):\n",
    "    logger.log(f'{log_name}')\n",
    "    start_log(logger, configt)\n",
    "    train_log(logger, configt)\n",
    "sample_log(logger, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load models --------\n",
    "# Separate networks for node features (x) and adjacency (adj), restored from the checkpoint.\n",
    "model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally copy the EMA weights stored in the checkpoint into the live model parameters.\n",
    "if config.sample.use_ema:\n",
    "    ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "    ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "    \n",
    "    ema_x.copy_to(model_x.parameters())\n",
    "    ema_adj.copy_to(model_adj.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed with the generation seed right before sampling (assumes load_seed seeds the global RNGs).\n",
    "print(f'GEN SEED: {config.sample.seed}')\n",
    "load_seed(config.sample.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SDEs for node features and adjacency, taken from the checkpoint config.\n",
    "sde_x = load_sde(configt.sde.x)\n",
    "sde_adj = load_sde(configt.sde.adj)\n",
    "max_node_num  = configt.data.max_node_num\n",
    "\n",
    "# Torch device string: 'cuda:<first id>' when `device` is a list of GPU ids, otherwise `device` unchanged.\n",
    "device_id = f'cuda:{device[0]}' if isinstance(device, list) else device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `continuous` is passed to get_score_fn when the score functions are built.\n",
    "continuous = True\n",
    "\n",
    "# Names selecting the predictor ('Reverse' or Euler-Maruyama) and corrector ('Langevin' or none) classes.\n",
    "predictor = config.sampler.predictor\n",
    "corrector = config.sampler.corrector\n",
    "\n",
    "# Corrector parameters.\n",
    "snr=config.sampler.snr\n",
    "scale_eps=config.sampler.scale_eps\n",
    "n_steps=config.sampler.n_steps\n",
    "\n",
    "# probability_flow is passed to the predictors; eps is the final time of the reverse-time grid in `sample`.\n",
    "# NOTE(review): `denoise` does not appear to be used anywhere else in this notebook.\n",
    "probability_flow = config.sample.probability_flow\n",
    "eps = config.sample.eps\n",
    "denoise = config.sample.noise_removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QM9: all 10000 samples in one batch. Other datasets (e.g. Zinc250k, which runs out of memory at 10000)\n",
    "# fall back to the batch size stored in the checkpoint config.\n",
    "if configt.data.data in ['QM9']: # Zinc250k gives OOM with 10000 samples generated\n",
    "    batch_size = 10000\n",
    "    # batch_size = 1000\n",
    "else:\n",
    "    batch_size = configt.data.batch_size\n",
    "shape_x = (batch_size, max_node_num, configt.data.max_feat_num)\n",
    "shape_adj = (batch_size, max_node_num, max_node_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge-count constraint: a generated graph satisfies it when it has at most the 20th percentile\n",
    "# of the test graphs' edge counts.\n",
    "max_edges = np.percentile([g.number_of_edges() for g in test_graph_list], 20)\n",
    "max_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.graph_utils import graphs_to_tensor, node_flags\n",
    "\n",
    "# Pad test graphs to the checkpoint's max_node_num (configt), the same size used for shape_x / shape_adj.\n",
    "max_node_num = configt.data.max_node_num\n",
    "graph_tensor = graphs_to_tensor(test_graph_list, max_node_num)\n",
    "# NOTE(review): this draw is overwritten below; it is kept because removing it would shift the\n",
    "# NumPy RNG stream and change which test graphs are selected.\n",
    "idx = np.random.randint(0, len(test_graph_list), batch_size)\n",
    "# The first n_plot batch entries are all copies of test graph idx_sample, so it can be plotted later.\n",
    "idx_sample = 0\n",
    "n_plot = 5\n",
    "idx = np.array([idx_sample]*n_plot + list(np.random.randint(0, len(test_graph_list), batch_size-n_plot)))\n",
    "# Node flags of the selected test graphs; every sample below is masked with them.\n",
    "init_flags_iter = node_flags(graph_tensor[idx]).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pct_obs = 0.5\n",
    "same_graph_same_obs = True\n",
    "symmetric = True\n",
    "remove_self_loops = True\n",
    "observed = \"edges\" # one of \"entries\", \"edges\"\n",
    "\n",
    "ground_truth_adjs = graph_tensor[idx].to(device_id)\n",
    "\n",
    "# idx_observed: (batch, row, col) indices of the adjacency entries whose ground-truth value is revealed.\n",
    "#   \"entries\": each entry (edge or non-edge) is observed when its uniform draw is below pct_obs.\n",
    "#   \"edges\":   a pct_obs fraction of the existing edges is observed.\n",
    "if observed == \"entries\":\n",
    "    if same_graph_same_obs:\n",
    "        # One draw per test graph, so repeated copies of a graph share the same observed entries.\n",
    "        random_samps = torch.rand(len(test_graph_list), max_node_num, max_node_num)[idx].to(device_id)\n",
    "    else:\n",
    "        random_samps = torch.rand(*ground_truth_adjs.shape).to(device_id)\n",
    "    \n",
    "    if symmetric:\n",
    "        random_samps = (random_samps + random_samps.transpose(-1, -2))/2\n",
    "    \n",
    "    bool_tensor = random_samps < pct_obs\n",
    "    if remove_self_loops:\n",
    "        bool_tensor = bool_tensor & ~torch.eye(max_node_num, device=device_id).bool()\n",
    "    idx_observed = torch.where(bool_tensor)\n",
    "elif observed == \"edges\":\n",
    "    # NOTE(review): same_graph_same_obs is not applied in this mode.\n",
    "    if symmetric:\n",
    "        # Sample undirected edges from the strict upper triangle and reveal both (u, v) and (v, u),\n",
    "        # so no edge is only half-observed in the symmetric adjacency.\n",
    "        n_pad = ground_truth_adjs.shape[-1]\n",
    "        upper = torch.ones(n_pad, n_pad, dtype=torch.bool, device=ground_truth_adjs.device).triu(diagonal=1)\n",
    "        idx_edges = torch.where((ground_truth_adjs != 0) & upper)\n",
    "    else:\n",
    "        idx_edges = torch.where(ground_truth_adjs != 0)\n",
    "    n_edges_tot = idx_edges[0].shape[0]\n",
    "    idx_obs = torch.randperm(n_edges_tot)[:int(pct_obs*n_edges_tot)].to(idx_edges[0].device)\n",
    "    obs_b, obs_r, obs_c = idx_edges[0][idx_obs], idx_edges[1][idx_obs], idx_edges[2][idx_obs]\n",
    "    if symmetric:\n",
    "        obs_b, obs_r, obs_c = torch.cat([obs_b, obs_b]), torch.cat([obs_r, obs_c]), torch.cat([obs_c, obs_r])\n",
    "    idx_observed = (obs_b, obs_r, obs_c)\n",
    "else:\n",
    "    raise ValueError(f\"unknown observation mode: {observed!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move each index tensor of idx_observed to the sampling device.\n",
    "idx_observed = tuple(index.to(device_id) for index in idx_observed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial noise from the SDE priors (symmetric sampler for the adjacency); reused by every GDSS-based run below.\n",
    "init_x = sde_x.prior_sampling(shape_x).to(device_id)\n",
    "init_adj = sde_adj.prior_sampling_sym(shape_adj).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start the reverse process with the observed entries already set to their ground-truth values (modifies init_adj in place).\n",
    "init_adj_ground_truth = True\n",
    "if init_adj_ground_truth:\n",
    "    init_adj[idx_observed] = ground_truth_adjs[idx_observed]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None,\n",
    "           ground_truth_adjs=None, idx_observed=None):\n",
    "    \"\"\"Joint predictor-corrector reverse diffusion over node features and adjacency.\n",
    "\n",
    "    Args:\n",
    "        predictor_x, corrector_x: node-feature predictor / corrector; each exposes\n",
    "            ``update_fn(x, adj, flags, t) -> (state, mean)``, and the predictor also has ``.sde``.\n",
    "        predictor_adj, corrector_adj: the same for the adjacency.\n",
    "        init_x, init_adj: optional starting states (cloned, never modified); when None they are\n",
    "            drawn from the SDE priors using the notebook globals ``shape_x`` / ``shape_adj``.\n",
    "        flags: node flags used to mask x and adj.\n",
    "        ground_truth_adjs, idx_observed: when both are given, ``adj[idx_observed]`` is reset to\n",
    "            ``ground_truth_adjs[idx_observed]`` after every predictor step.\n",
    "\n",
    "    Returns:\n",
    "        (samples_int, x): ``quantize_mol(adj)`` and the binarized node features with an extra\n",
    "        last channel equal to ``1 - sum`` of the other channels.\n",
    "\n",
    "    Also reads the notebook globals ``eps`` (final time) and ``device_id``. The loop stops early,\n",
    "    keeping the current state, if the adjacency contains NaN after a corrector step.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        if init_x is not None:\n",
    "            x = init_x.clone()\n",
    "        else:\n",
    "            x = predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        if init_adj is not None:\n",
    "            adj = init_adj.clone()\n",
    "        else:\n",
    "            adj = predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "\n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "        diff_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, diff_steps, device=device_id)\n",
    "        # Batch size of the actual state; it need not equal the global shape_adj[0] when init_adj is given.\n",
    "        n_batch = adj.shape[0]\n",
    "        clamp_observed = ground_truth_adjs is not None and idx_observed is not None\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for i in trange(0, diff_steps, desc='[Sampling]', position=1, leave=False):\n",
    "            t = timesteps[i]\n",
    "            vec_t = torch.ones(n_batch, device=t.device) * t\n",
    "\n",
    "            # x and adj are both advanced from the pre-step features (x_prev).\n",
    "            x_prev = x\n",
    "            x, _ = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, _ = corrector_adj.update_fn(x_prev, adj, flags, vec_t)\n",
    "            if torch.any(torch.isnan(adj)):\n",
    "                break\n",
    "\n",
    "            x_prev = x\n",
    "            x, _ = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, _ = predictor_adj.update_fn(x_prev, adj, flags, vec_t)\n",
    "\n",
    "            if clamp_observed:\n",
    "                adj[idx_observed] = ground_truth_adjs[idx_observed]\n",
    "\n",
    "    samples_int = quantize_mol(adj)\n",
    "\n",
    "    x = torch.where(x > 0.5, 1, 0)\n",
    "    x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)      # 32, 9, 4 -> 32, 9, 5\n",
    "    return samples_int, x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_samples(gen_graph_list, gen_adjs, test_graph_list, idx_order, idx_observed, idx_sample, n_graphs, ax=None, layout=None, method=None):\n",
    "    \"\"\"Draw test graph `idx_sample` next to up to `n_graphs` samples generated from it.\n",
    "\n",
    "    Samples are the positions i with ``idx_order[i] == idx_sample``. Node pairs among the test\n",
    "    graph's nodes are drawn on the test graph's layout as:\n",
    "      * observed and present in the sample: solid, green if it agrees with the test adjacency, else red;\n",
    "      * observed but absent from the sample: dotted, coloured by the same check;\n",
    "      * unobserved and present in the sample: solid black.\n",
    "    The number of correct / incorrect observed pairs of each sample is printed.\n",
    "\n",
    "    Args:\n",
    "        gen_graph_list: not used; kept for call compatibility.\n",
    "        gen_adjs: per-sample adjacency matrices indexable by sample position; not modified.\n",
    "        test_graph_list: networkx test graphs.\n",
    "        idx_order: array giving, for each sample position, the test-graph index it was generated from.\n",
    "        idx_observed: (batch, row, col) index sequences of the observed adjacency entries.\n",
    "        idx_sample: index of the test graph to show.\n",
    "        n_graphs: maximum number of samples to draw.\n",
    "        ax: optional sequence of Axes, one per drawn sample. When None, a new figure is created\n",
    "            with an extra leading panel showing the test graph.\n",
    "        layout: optional node-position dict; defaults to a spring layout of the test graph.\n",
    "        method: optional label set as the y-label of the first sample panel.\n",
    "\n",
    "    Reads the notebook-global ``graph_tensor`` for the padded test adjacency.\n",
    "    \"\"\"\n",
    "    test_graph_sample = nx.convert_node_labels_to_integers(test_graph_list[idx_sample])\n",
    "    n_nodes = test_graph_sample.number_of_nodes()\n",
    "\n",
    "    if layout is None:\n",
    "        layout = nx.drawing.layout.spring_layout(test_graph_sample)\n",
    "\n",
    "    test_adj = graph_tensor[idx_sample].cpu().numpy()\n",
    "\n",
    "    # Sample positions generated from this test graph, capped at n_graphs.\n",
    "    graphs_represent = np.nonzero(idx_order == idx_sample)[0][:n_graphs]\n",
    "    n_graphs = graphs_represent.shape[0]\n",
    "\n",
    "    if ax is None:\n",
    "        # squeeze=False keeps the Axes indexable even when only the test-graph panel exists.\n",
    "        _, ax_grid = plt.subplots(1, 1 + n_graphs, figsize=(15, 5), squeeze=False)\n",
    "        ax = ax_grid[0]\n",
    "        nx.draw(test_graph_sample, pos=layout, ax=ax[0], with_labels=True)\n",
    "        ax[0].set_title(\"Test Graph\")\n",
    "        axs = ax[1:]\n",
    "    else:\n",
    "        axs = ax\n",
    "\n",
    "    for i, graph_idx in enumerate(graphs_represent):\n",
    "        gen_graph = nx.Graph()\n",
    "        gen_graph.add_nodes_from(np.arange(n_nodes))\n",
    "        # Private copy so the caller's adjacency is never modified.\n",
    "        gen_adj = np.array(gen_adjs[graph_idx], copy=True)\n",
    "\n",
    "        in_sample = idx_observed[0] == graph_idx\n",
    "        observed_pairs = {\n",
    "            (min(u, v), max(u, v))\n",
    "            for u, v in zip(idx_observed[1][in_sample].tolist(), idx_observed[2][in_sample].tolist())\n",
    "            if u < n_nodes and v < n_nodes\n",
    "        }\n",
    "        # Undirected edges of the sample among the test graph's nodes (diagonal ignored).\n",
    "        rows, cols = np.nonzero(gen_adj)\n",
    "        gen_pairs = {\n",
    "            (min(u, v), max(u, v))\n",
    "            for u, v in zip(rows.tolist(), cols.tolist())\n",
    "            if u != v and u < n_nodes and v < n_nodes\n",
    "        }\n",
    "\n",
    "        # Edge lists and colour lists are filled in the same loop, so they always line up.\n",
    "        gen_graph_edges, edge_colors = [], []\n",
    "        new_edges, new_edge_colors = [], []\n",
    "        n_correct = 0\n",
    "        n_incorrect = 0\n",
    "        for u in range(n_nodes):\n",
    "            for v in range(u + 1, n_nodes):\n",
    "                pair = (u, v)\n",
    "                if pair in observed_pairs:\n",
    "                    matches = test_adj[u, v] == gen_adj[u, v] or test_adj[v, u] == gen_adj[v, u]\n",
    "                    if matches:\n",
    "                        n_correct += 1\n",
    "                    else:\n",
    "                        n_incorrect += 1\n",
    "                    colour = 'green' if matches else 'red'\n",
    "                    if pair in gen_pairs:\n",
    "                        gen_graph_edges.append(pair)\n",
    "                        edge_colors.append(colour)\n",
    "                    else:\n",
    "                        new_edges.append(pair)\n",
    "                        new_edge_colors.append(colour)\n",
    "                elif pair in gen_pairs:\n",
    "                    gen_graph_edges.append(pair)\n",
    "                    edge_colors.append('black')\n",
    "        print(f\"Graph {i + 1}: {n_correct} correct edges, {n_incorrect} incorrect edges\")\n",
    "        nx.draw_networkx_nodes(gen_graph, pos=layout, ax=axs[i])\n",
    "        nx.draw_networkx_labels(gen_graph, pos=layout, ax=axs[i])\n",
    "        nx.draw_networkx_edges(gen_graph, pos=layout, ax=axs[i], edgelist=new_edges, edge_color=new_edge_colors, style='dotted')\n",
    "        nx.draw_networkx_edges(gen_graph, pos=layout, ax=axs[i], edgelist=gen_graph_edges, edge_color=edge_colors)\n",
    "    if method is not None and len(axs) > 0:\n",
    "        axs[0].set_ylabel(method, fontsize=24)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DiGress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset DiGress guidance settings for the chosen observation mode.\n",
    "guidance_config_digress = load_yaml_config(f'config_guidance/incomplete/{observed}/digress.yaml')\n",
    "digress_settings = guidance_config_digress[configt.data.data.lower()]\n",
    "digress_lambda = digress_settings['guidance_lambda']\n",
    "digress_base_config_path = digress_settings['base_config_path']\n",
    "digress_ckpt_path = digress_settings['ckpt_path']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guided DiGress sampling; the outputs are used below as integer node / edge class tensors.\n",
    "x_digress, adj_digress = sample_digress(idx_observed, ground_truth_adjs, batch_size, digress_base_config_path, digress_ckpt_path, device_id, guidance_lambda=digress_lambda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of observed adjacency entries whose generated value equals the ground truth.\n",
    "acc_digress = (adj_digress[idx_observed] == ground_truth_adjs[idx_observed]).sum().item() / idx_observed[0].shape[0]\n",
    "acc_digress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Negative node classes are treated as padding (NOTE(review): confirm against sample_digress). They get an\n",
    "# all-zero one-hot row, so the extra last channel (1 - sum) marks them as absent atoms instead of class 0.\n",
    "node_is_real = x_digress >= 0\n",
    "x_digress = torch.nn.functional.one_hot(x_digress.clamp(min=0).long(), num_classes=configt.data.max_feat_num)\n",
    "x_digress = x_digress * node_is_real.unsqueeze(-1)\n",
    "x_digress = torch.concat([x_digress, 1 - x_digress.sum(dim=-1, keepdim=True)], dim=-1)      # 32, 9, 4 -> 32, 9, 5\n",
    "\n",
    "adj_digress_mod = adj_digress.clone() - 1\n",
    "adj_digress_mod[adj_digress_mod < 0] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_digress = torch.nn.functional.one_hot(adj_digress_mod.long(), num_classes=4).permute(0, 3, 1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode molecules; keep networkx graphs (NSPDK, edge counts) and non-empty SMILES (MOSES metrics).\n",
    "gen_mols, num_mols_wo_correction = gen_mol(x_digress, adj_onehot_digress, configt.data.data)\n",
    "gen_graph_list_digress = mols_to_nx(gen_mols)\n",
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "# Quick visual check of the first generated graph.\n",
    "nx.draw(gen_graph_list_digress[0], with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MOSES metrics against the test SMILES and NSPDK MMD against the test graphs.\n",
    "scores_digress = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_digress = eval_graph_list(test_graph_list, gen_graph_list_digress, methods=['nspdk'])['nspdk']\n",
    "scores_digress, scores_nspdk_digress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of generated graphs with at most max_edges edges.\n",
    "satisfies_digress = np.mean([g.number_of_edges() <= max_edges for g in gen_graph_list_digress])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DiGress unguided"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unguided baseline: same DiGress config and checkpoint as the guided run, with guidance_lambda = 0.\n",
    "guidance_lambda = 0.0\n",
    "x_digress_unguided, adj_digress_unguided = sample_digress(idx_observed, ground_truth_adjs, batch_size, digress_base_config_path, digress_ckpt_path, device_id, guidance_lambda=guidance_lambda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of observed adjacency entries whose generated value equals the ground truth.\n",
    "acc_digress_unguided = (adj_digress_unguided[idx_observed] == ground_truth_adjs[idx_observed]).sum().item() / idx_observed[0].shape[0]\n",
    "acc_digress_unguided"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Negative node classes are treated as padding (NOTE(review): confirm against sample_digress). They get an\n",
    "# all-zero one-hot row, so the extra last channel (1 - sum) marks them as absent atoms instead of class 0.\n",
    "node_is_real_unguided = x_digress_unguided >= 0\n",
    "x_digress_unguided = torch.nn.functional.one_hot(x_digress_unguided.clamp(min=0).long(), num_classes=configt.data.max_feat_num)\n",
    "x_digress_unguided = x_digress_unguided * node_is_real_unguided.unsqueeze(-1)\n",
    "x_digress_unguided = torch.concat([x_digress_unguided, 1 - x_digress_unguided.sum(dim=-1, keepdim=True)], dim=-1)      # 32, 9, 4 -> 32, 9, 5\n",
    "\n",
    "adj_digress_mod = adj_digress_unguided.clone() - 1\n",
    "adj_digress_mod[adj_digress_mod < 0] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_digress_unguided = torch.nn.functional.one_hot(adj_digress_mod.long(), num_classes=4).permute(0, 3, 1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode molecules; keep networkx graphs (NSPDK, edge counts) and non-empty SMILES (MOSES metrics).\n",
    "gen_mols, num_mols_wo_correction = gen_mol(x_digress_unguided, adj_onehot_digress_unguided, configt.data.data)\n",
    "gen_graph_list_digress_unguided = mols_to_nx(gen_mols)\n",
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "# Quick visual check of the first generated graph.\n",
    "nx.draw(gen_graph_list_digress_unguided[0], with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MOSES metrics against the test SMILES and NSPDK MMD against the test graphs.\n",
    "scores_digress_unguided = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_digress_unguided = eval_graph_list(test_graph_list, gen_graph_list_digress_unguided, methods=['nspdk'])['nspdk']\n",
    "scores_digress_unguided, scores_nspdk_digress_unguided"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of generated graphs with at most max_edges edges.\n",
    "satisfies_digress_unguided = np.mean([g.number_of_edges() <= max_edges for g in gen_graph_list_digress_unguided])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): GGDiff guidance is loaded from config_guidance/constrained/nedges/, while the DiGress\n",
    "# run uses config_guidance/incomplete/{observed}/ - confirm this mix is intended.\n",
    "guidance_config = load_yaml_config(f'config_guidance/constrained/nedges/greedy.yaml')\n",
    "guidance_args = edict({'method': 'greedy', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions and predictor/corrector objects; only the adjacency predictor receives guidance_args.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "# Reverse diffusion from the shared initial noise, with observed entries clamped to the ground truth.\n",
    "adj_greedy, x_greedy = sample(predictor_obj_x, corrector_obj_x,\n",
    "                              predictor_obj_adj, corrector_obj_adj,\n",
    "                              init_x, init_adj, init_flags_iter,\n",
    "                              ground_truth_adjs, idx_observed)\n",
    "\n",
    "# Bond classes -> one-hot channels in the order (S, D, T, none).\n",
    "adj_greedy_mod = adj_greedy.copy() - 1\n",
    "adj_greedy_mod[adj_greedy_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_greedy = torch.nn.functional.one_hot(torch.tensor(adj_greedy_mod), num_classes=4).permute(0, 3, 1, 2)\n",
    "\n",
    "gen_mols, num_mols_wo_correction = gen_mol(x_greedy, adj_onehot_greedy, configt.data.data)\n",
    "\n",
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "\n",
    "gen_graph_list_greedy = mols_to_nx(gen_mols)\n",
    "\n",
    "# MOSES / NSPDK metrics, accuracy on observed entries, and edge-constraint satisfaction.\n",
    "scores_greedy = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_greedy = eval_graph_list(test_graph_list, gen_graph_list_greedy, methods=['nspdk'])['nspdk']\n",
    "\n",
    "acc_greedy = (torch.tensor(adj_greedy).to(device_id)[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "acc_greedy = acc_greedy.item()\n",
    "\n",
    "satisfies_greedy = np.mean([g.number_of_edges() <= max_edges for g in gen_graph_list_greedy])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): GGDiff guidance is loaded from config_guidance/constrained/nedges/, while the DiGress\n",
    "# run uses config_guidance/incomplete/{observed}/ - confirm this mix is intended.\n",
    "guidance_config = load_yaml_config(f'config_guidance/constrained/nedges/loss.yaml')\n",
    "guidance_args = edict({'method': 'loss', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions and predictor/corrector objects; only the adjacency predictor receives guidance_args.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "# Reverse diffusion from the shared initial noise, with observed entries clamped to the ground truth.\n",
    "adj_loss, x_loss = sample(predictor_obj_x, corrector_obj_x,\n",
    "                              predictor_obj_adj, corrector_obj_adj,\n",
    "                              init_x, init_adj, init_flags_iter,\n",
    "                              ground_truth_adjs, idx_observed)\n",
    "\n",
    "# Bond classes -> one-hot channels in the order (S, D, T, none).\n",
    "adj_loss_mod = adj_loss.copy() - 1\n",
    "adj_loss_mod[adj_loss_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_loss = torch.nn.functional.one_hot(torch.tensor(adj_loss_mod), num_classes=4).permute(0, 3, 1, 2)\n",
    "\n",
    "gen_mols, num_mols_wo_correction = gen_mol(x_loss, adj_onehot_loss, configt.data.data)\n",
    "\n",
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "\n",
    "gen_graph_list_loss = mols_to_nx(gen_mols)\n",
    "\n",
    "# MOSES / NSPDK metrics, accuracy on observed entries, and edge-constraint satisfaction.\n",
    "scores_loss = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_loss = eval_graph_list(test_graph_list, gen_graph_list_loss, methods=['nspdk'])['nspdk']\n",
    "\n",
    "acc_loss = (torch.tensor(adj_loss).to(device_id)[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "acc_loss = acc_loss.item()\n",
    "\n",
    "satisfies_loss = np.mean([g.number_of_edges() <= max_edges for g in gen_graph_list_loss])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): GGDiff guidance is loaded from config_guidance/constrained/nedges/, while the DiGress\n",
    "# run uses config_guidance/incomplete/{observed}/ - confirm this mix is intended.\n",
    "guidance_config = load_yaml_config(f'config_guidance/constrained/nedges/zero.yaml')\n",
    "guidance_args = edict({'method': 'zero', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions and predictor/corrector objects; only the adjacency predictor receives guidance_args.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "# Reverse diffusion from the shared initial noise, with observed entries clamped to the ground truth.\n",
    "adj_zero, x_zero = sample(predictor_obj_x, corrector_obj_x,\n",
    "                              predictor_obj_adj, corrector_obj_adj,\n",
    "                              init_x, init_adj, init_flags_iter,\n",
    "                              ground_truth_adjs, idx_observed)\n",
    "\n",
    "# Bond classes -> one-hot channels in the order (S, D, T, none).\n",
    "adj_zero_mod = adj_zero.copy() - 1\n",
    "adj_zero_mod[adj_zero_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_zero = torch.nn.functional.one_hot(torch.tensor(adj_zero_mod), num_classes=4).permute(0, 3, 1, 2)\n",
    "\n",
    "gen_mols, num_mols_wo_correction = gen_mol(x_zero, adj_onehot_zero, configt.data.data)\n",
    "\n",
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "\n",
    "gen_graph_list_zero = mols_to_nx(gen_mols)\n",
    "\n",
    "# MOSES / NSPDK metrics, accuracy on observed entries, and edge-constraint satisfaction.\n",
    "scores_zero = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_zero = eval_graph_list(test_graph_list, gen_graph_list_zero, methods=['nspdk'])['nspdk']\n",
    "\n",
    "acc_zero = (torch.tensor(adj_zero).to(device_id)[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "acc_zero = acc_zero.item()\n",
    "\n",
    "satisfies_zero = np.mean([g.number_of_edges() <= max_edges for g in gen_graph_list_zero])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## No Guidance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unguided GDSS baseline: same sampler, initial noise and clamping, but no guidance_args for the adjacency predictor.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_noguid, x_noguid = sample(predictor_obj_x, corrector_obj_x,\n",
    "                              predictor_obj_adj, corrector_obj_adj,\n",
    "                              init_x, init_adj, init_flags_iter,\n",
    "                              ground_truth_adjs, idx_observed)\n",
    "\n",
    "# Bond classes -> one-hot channels in the order (S, D, T, none).\n",
    "adj_noguid_mod = adj_noguid.copy() - 1\n",
    "adj_noguid_mod[adj_noguid_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_noguid = torch.nn.functional.one_hot(torch.tensor(adj_noguid_mod), num_classes=4).permute(0, 3, 1, 2)\n",
    "\n",
    "gen_mols, num_mols_wo_correction = gen_mol(x_noguid, adj_onehot_noguid, configt.data.data)\n",
    "\n",
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "\n",
    "gen_graph_list_noguid = mols_to_nx(gen_mols)\n",
    "\n",
    "# MOSES / NSPDK metrics, accuracy on observed entries, and edge-constraint satisfaction.\n",
    "scores_noguid = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_noguid = eval_graph_list(test_graph_list, gen_graph_list_noguid, methods=['nspdk'])['nspdk']\n",
    "\n",
    "acc_noguid = (torch.tensor(adj_noguid).to(device_id)[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "acc_noguid = acc_noguid.item()\n",
    "\n",
    "satisfies_noguid = np.mean([g.number_of_edges() <= max_edges for g in gen_graph_list_noguid])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_edges(graphs):\n",
    "    \"\"\"Number of edges of every graph in `graphs`, as a NumPy array.\"\"\"\n",
    "    return np.array([graph.number_of_edges() for graph in graphs])\n",
    "\n",
    "\n",
    "n_edges_greedy = count_edges(gen_graph_list_greedy)\n",
    "n_edges_loss = count_edges(gen_graph_list_loss)\n",
    "n_edges_zero = count_edges(gen_graph_list_zero)\n",
    "n_edges_uncons = count_edges(gen_graph_list_noguid)\n",
    "n_edges_digress = count_edges(gen_graph_list_digress)\n",
    "n_edges_digress_unguided = count_edges(gen_graph_list_digress_unguided)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def unique_pct(scores):\n",
    "    \"\"\"Percentage of unique molecules from a MOSES get_all_metrics() result.\n",
    "\n",
    "    The key is 'unique@<k>' with k = len(gen_smiles) at the call sites. That is only 10000 when every\n",
    "    one of 10000 samples yields a non-empty SMILES, so the key is looked up by prefix.\n",
    "    \"\"\"\n",
    "    return next((100 * value for key, value in scores.items() if key.startswith('unique@')), float('nan'))\n",
    "\n",
    "\n",
    "def mean_std_tex(values):\n",
    "    \"\"\"Format an array as a LaTeX '$mean \\\\pm std$' string with one decimal.\"\"\"\n",
    "    return f'${values.mean():.1f}\\\\pm {values.std():.1f}$'\n",
    "\n",
    "\n",
    "# Row order: GGDiff-G (loss), GGDiff-C (greedy), GGDiff-Z (zero), unguided GDSS, guided / unguided DiGress.\n",
    "df_results = pd.DataFrame({\n",
    "    'Method': ['GGDiff-G', 'GGDiff-C', 'GGDiff-Z', 'Uncons. (GDSS)', 'DiGress', 'Uncons. (DiGress)'],\n",
    "    '% Satisfies Constraints': [100*s for s in (satisfies_loss, satisfies_greedy, satisfies_zero, satisfies_noguid, satisfies_digress, satisfies_digress_unguided)],\n",
    "    '% Observed Acc.': [100*a for a in (acc_loss, acc_greedy, acc_zero, acc_noguid, acc_digress, acc_digress_unguided)],\n",
    "    '% Unique': [unique_pct(s) for s in (scores_loss, scores_greedy, scores_zero, scores_noguid, scores_digress, scores_digress_unguided)],\n",
    "    'Num. Edges': [mean_std_tex(n) for n in (n_edges_loss, n_edges_greedy, n_edges_zero, n_edges_uncons, n_edges_digress, n_edges_digress_unguided)],\n",
    "}).set_index('Method').round(4)\n",
    "df_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LaTeX version of the results table, with numeric cells formatted to 2 decimals.\n",
    "print(df_results.style.format(precision=2).to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "digress-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
