{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from easydict import EasyDict as edict\n",
    "import copy\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import ReverseDiffusionPredictor, EulerMaruyamaPredictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x\n",
    "from tqdm.notebook import trange\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import init_flags, quantize_mol\n",
    "from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx\n",
    "from moses.metrics.metrics import get_all_metrics\n",
    "from evaluation.stats import eval_graph_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = load_device()\n",
    "device = [0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_file = 'sample_zinc250k'\n",
    "seed = 0\n",
    "config = get_config(config_file, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load checkpoint --------\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_seed(configt.seed)\n",
    "# train_graph_list, _ = load_data(configt, get_graph_list=True)\n",
    "with open(f'data/{configt.data.data.lower()}_test_nx.pkl', 'rb') as f:\n",
    "    test_graph_list = pickle.load(f)                                   # for NSPDK MMD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_smiles, test_smiles = load_smiles(configt.data.data)\n",
    "train_smiles, test_smiles = canonicalize_smiles(train_smiles), canonicalize_smiles(test_smiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "logger = Logger(str(os.path.join(log_dir, f'{log_name}.log')), mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not check_log(log_folder_name, log_name):\n",
    "    logger.log(f'{log_name}')\n",
    "    start_log(logger, configt)\n",
    "    train_log(logger, configt)\n",
    "sample_log(logger, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load models --------\n",
    "model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if config.sample.use_ema:\n",
    "    ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "    ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "    \n",
    "    ema_x.copy_to(model_x.parameters())\n",
    "    ema_adj.copy_to(model_adj.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'GEN SEED: {config.sample.seed}')\n",
    "load_seed(config.sample.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sde_x = load_sde(configt.sde.x)\n",
    "sde_adj = load_sde(configt.sde.adj)\n",
    "max_node_num  = configt.data.max_node_num\n",
    "\n",
    "device_id = f'cuda:{device[0]}' if isinstance(device, list) else device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "continuous = True\n",
    "\n",
    "predictor = config.sampler.predictor\n",
    "corrector = config.sampler.corrector\n",
    "\n",
    "snr=config.sampler.snr\n",
    "scale_eps=config.sampler.scale_eps\n",
    "n_steps=config.sampler.n_steps\n",
    "\n",
    "probability_flow = config.sample.probability_flow\n",
    "eps = config.sample.eps\n",
    "denoise = config.sample.noise_removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if configt.data.data in ['QM9']: # Zinc250k gives OOM with 10000 samples generated\n",
    "    batch_size = 10000\n",
    "else:\n",
    "    batch_size = configt.data.batch_size\n",
    "shape_x = (batch_size, max_node_num, configt.data.max_feat_num)\n",
    "shape_adj = (batch_size, max_node_num, max_node_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.graph_utils import graphs_to_tensor, node_flags\n",
    "\n",
    "max_node_num = config.data.max_node_num\n",
    "graph_tensor = graphs_to_tensor(test_graph_list, max_node_num)\n",
    "idx = np.random.randint(0, len(test_graph_list), batch_size)\n",
    "idx_sample = 0 # To plot later, we ensure that there are at least n_plot graphs of idx_sample sample\n",
    "n_plot = 5\n",
    "idx = np.array([idx_sample]*n_plot + list(np.random.randint(0, len(test_graph_list), batch_size-n_plot)))\n",
    "init_flags_iter = node_flags(graph_tensor[idx]).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pct_obs = 0.5\n",
    "same_graph_same_obs = True\n",
    "symmetric = True\n",
    "remove_self_loops = True\n",
    "observed = \"entries\" # \"entries\" \"edges\"\n",
    " \n",
    "ground_truth_adjs = graph_tensor[idx].to(device_id)\n",
    "\n",
    "if observed == \"entries\":\n",
    "    if same_graph_same_obs:\n",
    "        random_samps = torch.rand(len(test_graph_list), max_node_num, max_node_num)[idx].to(device_id)\n",
    "    else:\n",
    "        random_samps = torch.rand(*ground_truth_adjs.shape).to(device_id)\n",
    "    \n",
    "    if symmetric:\n",
    "        random_samps = (random_samps + random_samps.transpose(-1, -2))/2\n",
    "    \n",
    "    bool_tensor = random_samps < pct_obs\n",
    "    if remove_self_loops:\n",
    "        bool_tensor = bool_tensor & ~torch.eye(max_node_num, device=device_id).bool()\n",
    "    idx_observed = torch.where(bool_tensor)\n",
    "elif observed == \"edges\":\n",
    "    idx_edges = torch.where(ground_truth_adjs != 0)\n",
    "    n_edges_tot = idx_edges[0].shape[0]\n",
    "    idx_obs = torch.randperm(n_edges_tot)[:int(pct_obs*n_edges_tot)]\n",
    "    idx_observed = (idx_edges[0][idx_obs], idx_edges[1][idx_obs], idx_edges[2][idx_obs])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_observed = (idx_observed[0].to(device_id), idx_observed[1].to(device_id), idx_observed[2].to(device_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_x = sde_x.prior_sampling(shape_x).to(device_id)\n",
    "init_adj = sde_adj.prior_sampling_sym(shape_adj).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_adj_ground_truth = False\n",
    "if init_adj_ground_truth:\n",
    "    init_adj[idx_observed] = ground_truth_adjs[idx_observed]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None):\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        if init_x is not None:\n",
    "            x = init_x.clone()\n",
    "        else:\n",
    "            x = predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        if init_adj is not None:\n",
    "            adj = init_adj.clone()\n",
    "        else:\n",
    "            adj = predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "        \n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "        diff_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, diff_steps, device=device_id)\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):\n",
    "            t = timesteps[i]\n",
    "            vec_t = torch.ones(shape_adj[0], device=t.device) * t\n",
    "\n",
    "            _x = x\n",
    "            x, x_mean = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = corrector_adj.update_fn(_x, adj, flags, vec_t)\n",
    "            if torch.any(torch.isnan(adj)):\n",
    "                break\n",
    "\n",
    "            _x = x\n",
    "            x, x_mean = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = predictor_adj.update_fn(_x, adj, flags, vec_t)\n",
    "    samples_int = quantize_mol(adj)\n",
    "\n",
    "    # adj = torch.nn.functional.one_hot(torch.tensor(samples_int), num_classes=4).permute(0, 3, 1, 2)\n",
    "    x = torch.where(x > 0.5, 1, 0)\n",
    "    x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)      # 32, 9, 4 -> 32, 9, 5\n",
    "    return samples_int, x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_samples(gen_graph_list, gen_adjs, test_graph_list, idx_order, idx_observed, idx_sample, n_graphs, ax=None, layout=None, method=None):\n",
    "\n",
    "    test_graph_sample = nx.convert_node_labels_to_integers(test_graph_list[idx_sample])\n",
    "    n_nodes = test_graph_sample.number_of_nodes()\n",
    "\n",
    "    if layout is None:\n",
    "        layout = nx.drawing.layout.spring_layout(test_graph_sample)\n",
    "        \n",
    "    # Get the adjacency matrix of the test graph\n",
    "    test_adj = graph_tensor[idx_sample].cpu().numpy()\n",
    "    # Extract edges from the test graph\n",
    "    test_edges = list(test_graph_sample.edges())\n",
    "    \n",
    "    graphs_represent = np.nonzero(idx_order == idx_sample)[0]\n",
    "    n_graphs = min(n_graphs, graphs_represent.shape[0])\n",
    "    if graphs_represent.shape[0] > n_graphs:\n",
    "        graphs_represent = graphs_represent[:n_graphs]\n",
    "\n",
    "    if ax is None:\n",
    "        f, ax = plt.subplots(1, 1 + n_graphs, figsize=(15, 5))\n",
    "    \n",
    "        # Draw the test graph\n",
    "        nx.draw(test_graph_sample, pos=layout, ax=ax[0], with_labels=True)\n",
    "        ax[0].set_title(\"Test Graph\")\n",
    "        axs = ax[1:]\n",
    "    else:\n",
    "        axs = ax\n",
    "        \n",
    "    for i in range(n_graphs):\n",
    "        graph_idx = graphs_represent[i]\n",
    "        gen_graph = nx.Graph()# copy.deepcopy(gen_graph_list[graph_idx])\n",
    "        # if gen_graph.number_of_nodes() < n_nodes:\n",
    "        gen_graph.add_nodes_from(np.arange(n_nodes))\n",
    "        # Get the adjacency matrix of the generated graph\n",
    "        # gen_adj = nx.to_numpy_array(gen_graph)\n",
    "        gen_adj = gen_adjs[graph_idx]\n",
    "        np.fill_diagonal(gen_adj, 0.)\n",
    "    \n",
    "        entries_to_check = [\n",
    "            (u.item(), v.item()) for u, v in zip(idx_observed[1][idx_observed[0] == graph_idx], \n",
    "                                   idx_observed[2][idx_observed[0] == graph_idx])\n",
    "                    if u < n_nodes and v < n_nodes\n",
    "        ]\n",
    "    \n",
    "        # Determine edge colors based on adjacency matrix comparison\n",
    "        nonzero_entries = np.nonzero(gen_adj)\n",
    "        gen_graph_edges = []\n",
    "        for j in range(nonzero_entries[0].shape[0]):\n",
    "            if (nonzero_entries[0][j], nonzero_entries[1][j]) not in gen_graph_edges and \\\n",
    "                  (nonzero_entries[1][j], nonzero_entries[0][j]) not in gen_graph_edges and \\\n",
    "                  nonzero_entries[0][j] < n_nodes and nonzero_entries[1][j] < n_nodes:\n",
    "                gen_graph_edges.append((min(nonzero_entries[0][j], nonzero_entries[1][j]), max(nonzero_entries[0][j], nonzero_entries[1][j])))\n",
    "        # gen_graph.add_edges_from(gen_graph_edges)\n",
    "        edge_colors = []\n",
    "        new_edges = []\n",
    "        new_edge_colors = []\n",
    "        n_correct = 0\n",
    "        n_incorrect = 0\n",
    "        for u in range(n_nodes):\n",
    "            for v in range(u + 1, n_nodes):\n",
    "                if (u, v) in entries_to_check or (v, u) in entries_to_check:\n",
    "                    if (u, v) in gen_graph_edges or (v, u) in gen_graph_edges:\n",
    "                        if test_adj[u, v] == gen_adj[u, v] or test_adj[v, u] == gen_adj[v, u]:  # Matching edge\n",
    "                            edge_colors.append('green')\n",
    "                            n_correct += 1\n",
    "                        else:  # Non-matching edge\n",
    "                            edge_colors.append('red')\n",
    "                            n_incorrect += 1\n",
    "                    else:\n",
    "                        new_edges.append((u, v))\n",
    "                        # gen_graph.add_edge(u, v)\n",
    "                        if test_adj[u, v] == gen_adj[u, v] or test_adj[v, u] == gen_adj[v, u]:\n",
    "                            new_edge_colors.append('green')\n",
    "                            n_correct += 1\n",
    "                        else:\n",
    "                            new_edge_colors.append('red')\n",
    "                            n_incorrect += 1\n",
    "                elif (u, v) in gen_graph_edges or (v, u) in gen_graph_edges:\n",
    "                    edge_colors.append('black')\n",
    "        print(f\"Graph {i + 1}: {n_correct} correct edges, {n_incorrect} incorrect edges\")\n",
    "        assert len(edge_colors) == len(gen_graph_edges), f\"{edge_colors} {gen_graph_edges}\"\n",
    "        # Draw the generated graph with colored edges\n",
    "        nx.draw_networkx_nodes(gen_graph, pos=layout, ax=axs[i])\n",
    "        nx.draw_networkx_labels(gen_graph, pos=layout, ax=axs[i])\n",
    "        nx.draw_networkx_edges(gen_graph, pos=layout, ax=axs[i], edgelist=new_edges, edge_color=new_edge_colors, style='dotted')\n",
    "        nx.draw_networkx_edges(gen_graph, pos=layout, ax=axs[i], edgelist=gen_graph_edges, edge_color=edge_colors)\n",
    "        # ax[i].set_title(f\"Generated Graph {i + 1}\")\n",
    "    if method is not None:\n",
    "        axs[0].set_ylabel(method, fontsize=24)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/incomplete/{observed}/greedy.yaml')\n",
    "guidance_args = edict({'method': 'greedy', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})\n",
    "guidance_args['loss_kwargs'] = {'idx_obs': idx_observed, 'true_adj': ground_truth_adjs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_greedy, x_greedy = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_greedy_mod = adj_greedy.copy() - 1\n",
    "adj_greedy_mod[adj_greedy_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_greedy = torch.nn.functional.one_hot(torch.tensor(adj_greedy_mod), num_classes=4).permute(0, 3, 1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "gen_mols, num_mols_wo_correction = gen_mol(x_greedy, adj_onehot_greedy, configt.data.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_graph_list_greedy = mols_to_nx(gen_mols)\n",
    "nx.draw(gen_graph_list_greedy[0], with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_greedy = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_greedy = eval_graph_list(test_graph_list, gen_graph_list_greedy, methods=['nspdk'])['nspdk']\n",
    "scores_greedy, scores_nspdk_greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_greedy = (torch.tensor(adj_greedy).to(device_id)[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "acc_greedy = acc_greedy.item()\n",
    "acc_greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_samples(gen_graph_list_greedy, adj_greedy, test_graph_list, idx, idx_observed, idx_sample, n_plot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/incomplete/{observed}/loss.yaml')\n",
    "guidance_args = edict({'method': 'loss', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})\n",
    "guidance_args['loss_kwargs'] = {'idx_obs': idx_observed, 'true_adj': ground_truth_adjs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_loss, x_loss = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_loss_mod = adj_loss.copy() - 1\n",
    "adj_loss_mod[adj_loss_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_loss = torch.nn.functional.one_hot(torch.tensor(adj_loss_mod), num_classes=4).permute(0, 3, 1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "gen_mols, num_mols_wo_correction = gen_mol(x_loss, adj_onehot_loss, configt.data.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_graph_list_loss = mols_to_nx(gen_mols)\n",
    "nx.draw(gen_graph_list_loss[0], with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_loss = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_loss = eval_graph_list(test_graph_list, gen_graph_list_loss, methods=['nspdk'])['nspdk']\n",
    "scores_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_loss = (torch.tensor(adj_loss).to(device_id)[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "acc_loss = acc_loss.item()\n",
    "acc_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_samples(gen_graph_list_loss, adj_loss, test_graph_list, idx, idx_observed, idx_sample, n_plot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/incomplete/{observed}/zero.yaml')\n",
    "guidance_args = edict({'method': 'zero', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})\n",
    "guidance_args['loss_kwargs'] = {'idx_obs': idx_observed, 'true_adj': ground_truth_adjs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_zero, x_zero = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_zero_mod = adj_zero.copy() - 1\n",
    "adj_zero_mod[adj_zero_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_zero = torch.nn.functional.one_hot(torch.tensor(adj_zero_mod), num_classes=4).permute(0, 3, 1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "gen_mols, num_mols_wo_correction = gen_mol(x_zero, adj_onehot_zero, configt.data.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_graph_list_zero = mols_to_nx(gen_mols)\n",
    "nx.draw(gen_graph_list_zero[0], with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_zero = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_zero = eval_graph_list(test_graph_list, gen_graph_list_zero, methods=['nspdk'])['nspdk']\n",
    "scores_zero, scores_nspdk_zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_zero = (torch.tensor(adj_zero).to(device_id)[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "acc_zero = acc_zero.item()\n",
    "acc_zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_samples(gen_graph_list_zero, adj_zero, test_graph_list, idx, idx_observed, idx_sample, n_plot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unconstrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "adj_uncons, x_uncons = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adj_uncons_mod = adj_uncons.copy() - 1\n",
    "adj_uncons_mod[adj_uncons_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "adj_onehot_uncons = torch.nn.functional.one_hot(torch.tensor(adj_uncons_mod), num_classes=4).permute(0, 3, 1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "gen_mols, num_mols_wo_correction = gen_mol(x_uncons, adj_onehot_uncons, configt.data.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_graph_list_uncons = mols_to_nx(gen_mols)\n",
    "nx.draw(gen_graph_list_uncons[0], with_labels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_smiles = mols_to_smiles(gen_mols)\n",
    "gen_smiles = [smi for smi in gen_smiles if len(smi)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores_uncons = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8, test=test_smiles)\n",
    "scores_nspdk_uncons = eval_graph_list(test_graph_list, gen_graph_list_uncons, methods=['nspdk'])['nspdk']\n",
    "scores_uncons, scores_nspdk_uncons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_uncons = (torch.tensor(adj_uncons).to(device_id)[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]#, adj_uncons[idx_observed].shape\n",
    "acc_uncons = acc_uncons.item()\n",
    "acc_uncons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_samples(gen_graph_list_uncons, adj_uncons, test_graph_list, idx, idx_observed, idx_sample, n_plot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = plt.subplots(4, 5, figsize=(20,12))\n",
    "\n",
    "for i in range(5):\n",
    "    j = np.random.randint(len(gen_graph_list_greedy))\n",
    "    nx.draw(gen_graph_list_greedy[j], ax=ax[0, i], node_size=10)\n",
    "    nx.draw(gen_graph_list_loss[j], ax=ax[1, i], node_size=10)\n",
    "    nx.draw(gen_graph_list_zero[j], ax=ax[2, i], node_size=10)\n",
    "    nx.draw(gen_graph_list_uncons[j], ax=ax[3, i], node_size=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_graph_sample = nx.convert_node_labels_to_integers(test_graph_list[idx_sample])\n",
    "n_nodes = test_graph_sample.number_of_nodes()\n",
    "\n",
    "layout = nx.drawing.layout.spring_layout(test_graph_sample)\n",
    "\n",
    "f = plt.figure(figsize=(20, 15))\n",
    "    \n",
    "# Draw the test graph\n",
    "ax_sample = f.add_subplot(3,6,7)\n",
    "nx.draw(test_graph_sample, pos=layout, ax=ax_sample, with_labels=True)\n",
    "ax_sample.set_xlabel(\"Sample graph\", fontsize=24)\n",
    "\n",
    "gen_graph_lists = [gen_graph_list_loss, gen_graph_list_greedy, gen_graph_list_zero, gen_graph_list_uncons]\n",
    "adjs_lists = [adj_loss, adj_greedy, adj_zero, adj_uncons]\n",
    "methods = [\"GGDiff-G\", \"GGDiff-C\", \"GGDiff-Z\", \"Uncons.\"]\n",
    "\n",
    "for i in range(4):\n",
    "    axs = [f.add_subplot(4,6,6*i+j+2) for j in range(5)]\n",
    "    plot_samples(gen_graph_lists[i], adjs_lists[i], test_graph_list, idx, idx_observed, idx_sample, n_plot, ax=axs, layout=layout, method=methods[i])\n",
    "\n",
    "if not os.path.exists('results/incomplete_graph_gen'):\n",
    "    os.makedirs('results/incomplete_graph_gen')\n",
    "f.savefig(f'results/incomplete_graph_gen/{configt.data.data}-{observed}.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_results = pd.DataFrame([scores_greedy, scores_loss, scores_zero, scores_uncons])\n",
    "df_results['Accuracy'] = [acc_greedy, acc_loss, acc_zero, acc_uncons]\n",
    "df_results.index = ['Greedy', 'Loss', 'Zero', 'Unconstrained']\n",
    "df_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"\\\\begin{table}[h]\")\n",
    "print(\"\\\\centering\")\n",
    "print(\"\\\\caption{Results for the incomplete graph generation experiment.}\")\n",
    "print(\"\\\\label{tab:incomplete_graph_gen}\")\n",
    "print(\"\\\\begin{tabular}{cccc}\")\n",
    "print(\"\\\\toprule\")\n",
    "print(\"\\\\textbf{Method} & \\\\textbf{Accuracy} & \\\\textbf{Pct Unique} \\\\\\\\\")\n",
    "print(\"\\\\midrule\")\n",
    "for index, row in df_results.iterrows():\n",
    "    print(f\"{index} & {100*row['Accuracy']:.2f} & {100*row['unique@'+str(batch_size)]:.2f} \\\\\\\\\")\n",
    "print(\"\\\\bottomrule\")\n",
    "print(\"\\\\end{tabular}\")\n",
    "print(\"\\\\end{table}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_path = f'results/incomplete_graph_gen/{configt.data.data}-{observed}.pkl'\n",
    "with open(file_path, 'wb') as f:\n",
    "    pickle.dump(df_results, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdss_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
