{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from tqdm.notebook import trange\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import ReverseDiffusionPredictor, EulerMaruyamaPredictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import init_flags, quantize_mol\n",
    "from utils.mol_utils import gen_mol, mols_to_smiles, load_smiles, canonicalize_smiles, mols_to_nx\n",
    "from moses.metrics.metrics import get_all_metrics\n",
    "from evaluation.stats import eval_graph_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query the available device(s), then pin the notebook to device index 0.\n",
    "# NOTE(review): the second assignment discards load_device()'s result; later cells\n",
    "# turn this list into 'cuda:0', so a CUDA device with index 0 is assumed -- confirm.\n",
    "device = load_device()\n",
    "device = [0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling-time configuration: the config name and the seed are both handed to get_config.\n",
    "config_file = 'sample_qm9'\n",
    "seed = 0\n",
    "config = get_config(config_file, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load checkpoint --------\n",
    "# `configt` is the configuration stored inside the checkpoint (presumably the one used\n",
    "# for training -- confirm); `config` remains the sampling configuration loaded above.\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed with the checkpoint's seed before touching the data.\n",
    "load_seed(configt.seed)\n",
    "# train_graph_list, _ = load_data(configt, get_graph_list=True)\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only open dataset files you trust.\n",
    "# The path is relative to the current working directory.\n",
    "with open(f'data/{configt.data.data.lower()}_test_nx.pkl', 'rb') as f:\n",
    "    test_graph_list = pickle.load(f)                                   # for NSPDK MMD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the train/test SMILES of the checkpoint's dataset and canonicalize each split.\n",
    "train_smiles, test_smiles = load_smiles(configt.data.data)\n",
    "train_smiles = canonicalize_smiles(train_smiles)\n",
    "test_smiles = canonicalize_smiles(test_smiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The log file is named after the sampled checkpoint and opened with mode='a'.\n",
    "log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "logger = Logger(str(os.path.join(log_dir, f'{log_name}.log')), mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The run header and the checkpoint configuration are logged only when check_log(...)\n",
    "# is falsy; the sampling configuration is logged on every execution of this cell.\n",
    "if not check_log(log_folder_name, log_name):\n",
    "    logger.log(f'{log_name}')\n",
    "    start_log(logger, configt)\n",
    "    train_log(logger, configt)\n",
    "sample_log(logger, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load models --------\n",
    "# Both networks (node features `x`, adjacency `adj`) are rebuilt from the parameters\n",
    "# and state dicts stored in the checkpoint.\n",
    "model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# When requested by the sampling config, load the EMA state from the checkpoint and\n",
    "# copy it onto the model parameters (presumably replacing the raw weights -- confirm\n",
    "# against the EMA helper's copy_to).\n",
    "if config.sample.use_ema:\n",
    "    ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "    ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "    \n",
    "    ema_x.copy_to(model_x.parameters())\n",
    "    ema_adj.copy_to(model_adj.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed with the sampling seed; the checkpoint's seed was applied before loading the data.\n",
    "print(f'GEN SEED: {config.sample.seed}')\n",
    "load_seed(config.sample.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SDEs for node features and adjacency, both built from the checkpoint configuration.\n",
    "sde_x = load_sde(configt.sde.x)\n",
    "sde_adj = load_sde(configt.sde.adj)\n",
    "max_node_num  = configt.data.max_node_num\n",
    "\n",
    "# A list of device indices becomes the torch device string of its first entry;\n",
    "# any other value is used unchanged.\n",
    "device_id = f'cuda:{device[0]}' if isinstance(device, list) else device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to get_score_fn when the score functions are built.\n",
    "continuous = True\n",
    "\n",
    "# Names that later select the predictor / corrector classes.\n",
    "predictor = config.sampler.predictor\n",
    "corrector = config.sampler.corrector\n",
    "\n",
    "# Arguments forwarded to the corrector constructor.\n",
    "snr=config.sampler.snr\n",
    "scale_eps=config.sampler.scale_eps\n",
    "n_steps=config.sampler.n_steps\n",
    "\n",
    "# `eps` is the final time of the reverse-diffusion grid used in sample().\n",
    "probability_flow = config.sample.probability_flow\n",
    "eps = config.sample.eps\n",
    "# NOTE(review): `denoise` is not read anywhere in the cells visible in this notebook.\n",
    "denoise = config.sample.noise_removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QM9 is sampled in a single batch of 10000 graphs; every other dataset falls back to the\n",
    "# checkpoint's batch size (Zinc250k gives OOM with 10000 samples generated).\n",
    "batch_size = 10000 if configt.data.data in ['QM9'] else configt.data.batch_size\n",
    "\n",
    "# Shapes of the node-feature and adjacency batches drawn from the priors.\n",
    "shape_adj = (batch_size, max_node_num, max_node_num)\n",
    "shape_x = (batch_size, max_node_num, configt.data.max_feat_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.graph_utils import graphs_to_tensor, node_flags\n",
    "\n",
    "# NOTE(review): max_node_num is rebound here from the sampling config, whereas shape_x and\n",
    "# shape_adj were built from the checkpoint config -- confirm that both values agree.\n",
    "max_node_num = config.data.max_node_num\n",
    "graph_tensor = graphs_to_tensor(test_graph_list, max_node_num)\n",
    "# NOTE(review): this first draw is overwritten below; its only effect is to advance\n",
    "# NumPy's global RNG state, so removing it would change the graphs that get selected.\n",
    "idx = np.random.randint(0, len(test_graph_list), batch_size)\n",
    "idx_sample = 0 # To plot later, we ensure that there are at least n_plot graphs of idx_sample sample\n",
    "n_plot = 5\n",
    "# The first n_plot slots all point at idx_sample; the remaining slots are random test graphs.\n",
    "idx = np.array([idx_sample]*n_plot + list(np.random.randint(0, len(test_graph_list), batch_size-n_plot)))\n",
    "# Node flags of the selected test graphs are reused as the sampling flags.\n",
    "init_flags_iter = node_flags(graph_tensor[idx]).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose which adjacency entries are revealed to the guided sampler.\n",
    "pct_obs = 0.5                # threshold on the uniform draws / fraction of edges kept\n",
    "same_graph_same_obs = True   # draw one mask per test graph and reuse it for every copy in `idx`\n",
    "symmetric = True             # make the mask symmetric in (i, j) / (j, i)\n",
    "remove_self_loops = True     # never mark diagonal entries as observed\n",
    "observed = \"entries\" # \"entries\" \"edges\"\n",
    "\n",
    "ground_truth_adjs = graph_tensor[idx].to(device_id)\n",
    "\n",
    "if observed == \"entries\":\n",
    "    if same_graph_same_obs:\n",
    "        random_samps = torch.rand(len(test_graph_list), max_node_num, max_node_num)[idx].to(device_id)\n",
    "    else:\n",
    "        random_samps = torch.rand(*ground_truth_adjs.shape).to(device_id)\n",
    "\n",
    "    if symmetric:\n",
    "        # Mirror the upper triangle so that (i, j) and (j, i) share ONE Uniform(0, 1) draw.\n",
    "        # Averaging the matrix with its transpose would also be symmetric, but the mean of two\n",
    "        # uniforms is triangular, so P(sample < pct_obs) would equal pct_obs only at 0.5.\n",
    "        random_samps = torch.triu(random_samps) + torch.triu(random_samps, diagonal=1).transpose(-1, -2)\n",
    "\n",
    "    bool_tensor = random_samps < pct_obs\n",
    "    if remove_self_loops:\n",
    "        bool_tensor = bool_tensor & ~torch.eye(max_node_num, device=device_id).bool()\n",
    "    idx_observed = torch.where(bool_tensor)\n",
    "elif observed == \"edges\":\n",
    "    # Keep a random subset of the non-zero entries of the ground-truth adjacency.\n",
    "    idx_edges = torch.where(ground_truth_adjs != 0)\n",
    "    n_edges_tot = idx_edges[0].shape[0]\n",
    "    idx_obs = torch.randperm(n_edges_tot)[:int(pct_obs*n_edges_tot)]\n",
    "    idx_observed = (idx_edges[0][idx_obs], idx_edges[1][idx_obs], idx_edges[2][idx_obs])\n",
    "else:\n",
    "    # Fail here rather than with a NameError on idx_observed in a later cell.\n",
    "    raise ValueError(f'Unknown observation mode: {observed!r}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the three index tensors (graph, row, column) onto the sampling device.\n",
    "idx_observed = (idx_observed[0].to(device_id), idx_observed[1].to(device_id), idx_observed[2].to(device_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the starting point once so that every guidance run below starts from the same noise.\n",
    "init_x = sde_x.prior_sampling(shape_x).to(device_id)\n",
    "init_adj = sde_adj.prior_sampling_sym(shape_adj).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: write the ground-truth values into the observed entries of the initial\n",
    "# adjacency (in place). Disabled by default.\n",
    "init_adj_ground_truth = False\n",
    "if init_adj_ground_truth:\n",
    "    init_adj[idx_observed] = ground_truth_adjs[idx_observed]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None):\n",
    "    \"\"\"Run the corrector/predictor reverse diffusion and return discretised samples.\n",
    "\n",
    "    Reads the notebook-level names `shape_x`, `shape_adj`, `eps` and `device_id`.\n",
    "\n",
    "    Args:\n",
    "        predictor_x, corrector_x: solver objects for the node features; their\n",
    "            `update_fn(x, adj, flags, vec_t)` is called once per step.\n",
    "        predictor_adj, corrector_adj: solver objects for the adjacency;\n",
    "            `predictor_adj.sde` provides the step count `N` and start time `T`.\n",
    "        init_x, init_adj: optional starting tensors. They are cloned, so the\n",
    "            caller's tensors are not modified; when None a prior sample is drawn.\n",
    "        flags: forwarded to `mask_x`, `mask_adjs` and every `update_fn`.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(samples_int, x)`: the result of `quantize_mol` on the final\n",
    "        adjacency, and the node features thresholded at 0.5 with one extra\n",
    "        trailing channel equal to `1 - sum` of the others.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        x = init_x.clone() if init_x is not None else predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        adj = init_adj.clone() if init_adj is not None else predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "\n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "\n",
    "        num_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, num_steps, device=device_id)\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for step in trange(0, num_steps, desc='[Sampling]', position=1, leave=False):\n",
    "            t = timesteps[step]\n",
    "            vec_t = torch.ones(shape_adj[0], device=t.device) * t\n",
    "\n",
    "            # Corrector: the adjacency update sees the features from BEFORE the feature update.\n",
    "            x_before = x\n",
    "            x, _ = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, _ = corrector_adj.update_fn(x_before, adj, flags, vec_t)\n",
    "            if torch.isnan(adj).any():\n",
    "                # Stop as soon as the adjacency contains NaNs.\n",
    "                break\n",
    "\n",
    "            # Predictor: same ordering convention as the corrector.\n",
    "            x_before = x\n",
    "            x, _ = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, _ = predictor_adj.update_fn(x_before, adj, flags, vec_t)\n",
    "\n",
    "    samples_int = quantize_mol(adj)\n",
    "\n",
    "    # Binarise the features, then append the complementary channel\n",
    "    # (e.g. 32, 9, 4 -> 32, 9, 5).\n",
    "    x = torch.where(x > 0.5, 1, 0)\n",
    "    x = torch.concat([x, 1 - x.sum(dim=-1, keepdim=True)], dim=-1)\n",
    "    return samples_int, x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_samples(gen_graph_list, gen_adjs, test_graph_list, idx_order, idx_observed, idx_sample, n_graphs, ax=None, layout=None, method=None):\n",
    "    \"\"\"Draw generated graphs for one test graph, colouring edges by agreement with it.\n",
    "\n",
    "    Observed node pairs are green when the generated entry equals the test entry and\n",
    "    red otherwise (dotted when the generated graph has no edge there); generated edges\n",
    "    on unobserved pairs are black. Reads the notebook-level `graph_tensor`.\n",
    "\n",
    "    Args:\n",
    "        gen_graph_list: unused; kept so existing calls keep working.\n",
    "        gen_adjs: generated adjacency matrices, indexed like `idx_order`. Not modified.\n",
    "        test_graph_list: list of test graphs (networkx).\n",
    "        idx_order: array giving, for each generated graph, its test-graph index.\n",
    "        idx_observed: tuple (graph index, row, column) of observed entries.\n",
    "        idx_sample: index of the test graph to display.\n",
    "        n_graphs: maximum number of generated graphs to draw.\n",
    "        ax: optional sequence of axes, one per generated graph. When None a new\n",
    "            figure is created with an extra leading panel showing the test graph.\n",
    "        layout: optional node positions; defaults to a spring layout of the test graph.\n",
    "        method: optional y-label written on the first generated-graph axis.\n",
    "    \"\"\"\n",
    "    test_graph_sample = nx.convert_node_labels_to_integers(test_graph_list[idx_sample])\n",
    "    n_nodes = test_graph_sample.number_of_nodes()\n",
    "\n",
    "    if layout is None:\n",
    "        layout = nx.drawing.layout.spring_layout(test_graph_sample)\n",
    "\n",
    "    # Reference adjacency of the displayed test graph\n",
    "    test_adj = graph_tensor[idx_sample].cpu().numpy()\n",
    "\n",
    "    # Generated graphs that were conditioned on this test graph\n",
    "    graphs_represent = np.nonzero(idx_order == idx_sample)[0]\n",
    "    n_graphs = min(n_graphs, graphs_represent.shape[0])\n",
    "    graphs_represent = graphs_represent[:n_graphs]\n",
    "\n",
    "    if ax is None:\n",
    "        # squeeze=False keeps an indexable array even when only the test-graph panel exists\n",
    "        _, ax = plt.subplots(1, 1 + n_graphs, figsize=(15, 5), squeeze=False)\n",
    "        ax = ax[0]\n",
    "\n",
    "        # Draw the test graph\n",
    "        nx.draw(test_graph_sample, pos=layout, ax=ax[0], with_labels=True)\n",
    "        ax[0].set_title(\"Test Graph\")\n",
    "        axs = ax[1:]\n",
    "    else:\n",
    "        axs = ax\n",
    "\n",
    "    for i in range(n_graphs):\n",
    "        graph_idx = graphs_represent[i]\n",
    "        # Only the nodes are added; edges are drawn from explicit edge lists below.\n",
    "        gen_graph = nx.Graph()\n",
    "        gen_graph.add_nodes_from(np.arange(n_nodes))\n",
    "\n",
    "        # Copy first: zeroing the diagonal must not modify the caller's `gen_adjs`.\n",
    "        gen_adj = np.array(gen_adjs[graph_idx])\n",
    "        np.fill_diagonal(gen_adj, 0.)\n",
    "\n",
    "        # Observed (row, column) pairs of this graph, restricted to real nodes\n",
    "        in_graph = idx_observed[0] == graph_idx\n",
    "        entries_to_check = {\n",
    "            (u.item(), v.item())\n",
    "            for u, v in zip(idx_observed[1][in_graph], idx_observed[2][in_graph])\n",
    "            if u < n_nodes and v < n_nodes\n",
    "        }\n",
    "\n",
    "        # Undirected generated edges as (min, max) pairs. Sorting makes their order match\n",
    "        # the (u, v) loop below, so edge_colors lines up with gen_graph_edges even when\n",
    "        # gen_adj is not symmetric.\n",
    "        nonzero_entries = np.nonzero(gen_adj)\n",
    "        gen_edge_set = {\n",
    "            (int(min(r, c)), int(max(r, c)))\n",
    "            for r, c in zip(nonzero_entries[0], nonzero_entries[1])\n",
    "            if r < n_nodes and c < n_nodes\n",
    "        }\n",
    "        gen_graph_edges = sorted(gen_edge_set)\n",
    "\n",
    "        edge_colors = []       # colours of gen_graph_edges, in the same order\n",
    "        new_edges = []         # observed pairs without a generated edge (drawn dotted)\n",
    "        new_edge_colors = []\n",
    "        n_correct = 0\n",
    "        n_incorrect = 0\n",
    "        for u in range(n_nodes):\n",
    "            for v in range(u + 1, n_nodes):\n",
    "                is_edge = (u, v) in gen_edge_set\n",
    "                if (u, v) in entries_to_check or (v, u) in entries_to_check:\n",
    "                    matches = test_adj[u, v] == gen_adj[u, v] or test_adj[v, u] == gen_adj[v, u]\n",
    "                    color = 'green' if matches else 'red'\n",
    "                    if matches:\n",
    "                        n_correct += 1\n",
    "                    else:\n",
    "                        n_incorrect += 1\n",
    "                    if is_edge:\n",
    "                        edge_colors.append(color)\n",
    "                    else:\n",
    "                        new_edges.append((u, v))\n",
    "                        new_edge_colors.append(color)\n",
    "                elif is_edge:\n",
    "                    edge_colors.append('black')\n",
    "        print(f\"Graph {i + 1}: {n_correct} correct edges, {n_incorrect} incorrect edges\")\n",
    "        assert len(edge_colors) == len(gen_graph_edges), f\"{edge_colors} {gen_graph_edges}\"\n",
    "\n",
    "        # Draw the generated graph with colored edges\n",
    "        nx.draw_networkx_nodes(gen_graph, pos=layout, ax=axs[i])\n",
    "        nx.draw_networkx_labels(gen_graph, pos=layout, ax=axs[i])\n",
    "        nx.draw_networkx_edges(gen_graph, pos=layout, ax=axs[i], edgelist=new_edges, edge_color=new_edge_colors, style='dotted')\n",
    "        nx.draw_networkx_edges(gen_graph, pos=layout, ax=axs[i], edgelist=gen_graph_edges, edge_color=edge_colors)\n",
    "\n",
    "    # Nothing to label when no generated-graph axis exists.\n",
    "    if method is not None and len(axs) > 0:\n",
    "        axs[0].set_ylabel(method, fontsize=24)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values swept for the `n_traj` guidance setting: 1, then the even numbers 2..20.\n",
    "N_vals = [1] + list(range(2, 21, 2))\n",
    "N_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guidance settings: the shared objective plus the dataset-specific entries of the YAML file.\n",
    "guidance_config = load_yaml_config(f'config_guidance/incomplete/{observed}/greedy.yaml')\n",
    "guidance_args = edict({'method': 'greedy', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})\n",
    "# The loss receives the observed index tuple and the ground-truth adjacencies.\n",
    "guidance_args['loss_kwargs'] = {'idx_obs': idx_observed, 'true_adj': ground_truth_adjs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Any predictor name other than 'Reverse' selects Euler-Maruyama; any corrector name\n",
    "# other than 'Langevin' selects the no-op corrector.\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "# The node-feature solvers take no guidance arguments and are shared by every run below.\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs_greedy = {}\n",
    "scores_greedy = {}\n",
    "scores_nspdk_greedy = {}\n",
    "\n",
    "# Sweep the `n_traj` guidance setting over N_vals; every run starts from the same init_x / init_adj.\n",
    "for n in N_vals:\n",
    "    guidance_args['n_traj'] = n\n",
    "    predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "    corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "    adj, x = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)\n",
    "    # Same guard as the zero / loss sweeps: a hugely negative entry marks a failed run\n",
    "    # (presumably NaNs cast to integers -- confirm against quantize_mol). Record the\n",
    "    # placeholder accuracy and skip the metrics instead of failing in one_hot below.\n",
    "    if adj.min() < -1000:\n",
    "        accs_greedy[n] = -1.\n",
    "        continue\n",
    "    adj_mod = adj.copy() - 1\n",
    "    adj_mod[adj_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "    adj_onehot = torch.nn.functional.one_hot(torch.tensor(adj_mod), num_classes=4).permute(0, 3, 1, 2)\n",
    "\n",
    "    gen_mols, num_mols_wo_correction = gen_mol(x, adj_onehot, configt.data.data)\n",
    "\n",
    "    gen_graph_list = mols_to_nx(gen_mols)\n",
    "\n",
    "    gen_smiles = mols_to_smiles(gen_mols)\n",
    "    gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "\n",
    "    # train=train_smiles: Novelty must be measured against this dataset's training\n",
    "    # molecules, not whatever default the metrics library would otherwise use.\n",
    "    scores = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8,\n",
    "                             test=test_smiles, train=train_smiles)\n",
    "    scores_nspdk = eval_graph_list(test_graph_list, gen_graph_list, methods=['nspdk'])['nspdk']\n",
    "\n",
    "    # Fraction of observed entries whose generated value equals the ground truth\n",
    "    adj_tensor = torch.tensor(adj).to(device_id)\n",
    "    acc = (adj_tensor[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "    acc = acc.item()\n",
    "\n",
    "    accs_greedy[n] = acc\n",
    "    scores_greedy[n] = scores\n",
    "    scores_nspdk_greedy[n] = scores_nspdk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guidance settings: the shared objective plus the dataset-specific entries of the YAML file.\n",
    "guidance_config = load_yaml_config(f'config_guidance/incomplete/{observed}/zero.yaml')\n",
    "guidance_args = edict({'method': 'zero', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})\n",
    "# The loss receives the observed index tuple and the ground-truth adjacencies.\n",
    "guidance_args['loss_kwargs'] = {'idx_obs': idx_observed, 'true_adj': ground_truth_adjs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs_zero = {}\n",
    "scores_zero = {}\n",
    "scores_nspdk_zero = {}\n",
    "\n",
    "# Sweep the `n_traj` guidance setting over N_vals; every run starts from the same init_x / init_adj.\n",
    "for n in N_vals:\n",
    "    guidance_args['n_traj'] = n\n",
    "    predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "    corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "    adj, x = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)\n",
    "    # A hugely negative entry marks a failed run (presumably NaNs cast to integers --\n",
    "    # confirm against quantize_mol): store the placeholder accuracy and skip the metrics.\n",
    "    if adj.min() < -1000:\n",
    "        accs_zero[n] = -1.\n",
    "        continue\n",
    "    adj_mod = adj.copy() - 1\n",
    "    adj_mod[adj_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "    adj_onehot = torch.nn.functional.one_hot(torch.tensor(adj_mod), num_classes=4).permute(0, 3, 1, 2)\n",
    "\n",
    "    gen_mols, num_mols_wo_correction = gen_mol(x, adj_onehot, configt.data.data)\n",
    "\n",
    "    gen_graph_list = mols_to_nx(gen_mols)\n",
    "\n",
    "    gen_smiles = mols_to_smiles(gen_mols)\n",
    "    gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "\n",
    "    # train=train_smiles: Novelty must be measured against this dataset's training\n",
    "    # molecules, not whatever default the metrics library would otherwise use.\n",
    "    scores = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8,\n",
    "                             test=test_smiles, train=train_smiles)\n",
    "    scores_nspdk = eval_graph_list(test_graph_list, gen_graph_list, methods=['nspdk'])['nspdk']\n",
    "\n",
    "    # Fraction of observed entries whose generated value equals the ground truth\n",
    "    adj_tensor = torch.tensor(adj).to(device_id)\n",
    "    acc = (adj_tensor[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "    acc = acc.item()\n",
    "\n",
    "    accs_zero[n] = acc\n",
    "    scores_zero[n] = scores\n",
    "    scores_nspdk_zero[n] = scores_nspdk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss - Ablation in $\\lambda$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guidance settings: the shared objective plus the dataset-specific entries of the YAML file.\n",
    "guidance_config = load_yaml_config(f'config_guidance/incomplete/{observed}/loss.yaml')\n",
    "guidance_args = edict({'method': 'loss', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})\n",
    "# The loss receives the observed index tuple and the ground-truth adjacencies.\n",
    "guidance_args['loss_kwargs'] = {'idx_obs': idx_observed, 'true_adj': ground_truth_adjs}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Values swept for `lr_guidance` (shown as lambda in the plots and tables below).\n",
    "lr_vals = [1e-1, 5e-1, 9e-1, 1.0, 2.0, 5., 10., 50., 100., 150., 200.]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accs_loss = {}\n",
    "scores_loss = {}\n",
    "scores_nspdk_loss = {}\n",
    "\n",
    "# Sweep the `lr_guidance` setting over lr_vals; every run starts from the same init_x / init_adj.\n",
    "for lr in lr_vals:\n",
    "    guidance_args['lr_guidance'] = lr\n",
    "    predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "    corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "    adj, x = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)\n",
    "    # A hugely negative entry marks a failed run (presumably NaNs cast to integers --\n",
    "    # confirm against quantize_mol): store the placeholder accuracy and skip the metrics.\n",
    "    if adj.min() < -1000:\n",
    "        accs_loss[lr] = -1.\n",
    "        continue\n",
    "    adj_mod = adj.copy() - 1\n",
    "    adj_mod[adj_mod == -1] = 3      # 0, 1, 2, 3 (no, S, D, T) -> 3, 0, 1, 2\n",
    "    adj_onehot = torch.nn.functional.one_hot(torch.tensor(adj_mod), num_classes=4).permute(0, 3, 1, 2)\n",
    "\n",
    "    gen_mols, num_mols_wo_correction = gen_mol(x, adj_onehot, configt.data.data)\n",
    "\n",
    "    gen_graph_list = mols_to_nx(gen_mols)\n",
    "\n",
    "    gen_smiles = mols_to_smiles(gen_mols)\n",
    "    gen_smiles = [smi for smi in gen_smiles if len(smi)]\n",
    "\n",
    "    # train=train_smiles: Novelty must be measured against this dataset's training\n",
    "    # molecules, not whatever default the metrics library would otherwise use.\n",
    "    scores = get_all_metrics(gen=gen_smiles, k=len(gen_smiles), device=device_id, n_jobs=8,\n",
    "                             test=test_smiles, train=train_smiles)\n",
    "    scores_nspdk = eval_graph_list(test_graph_list, gen_graph_list, methods=['nspdk'])['nspdk']\n",
    "\n",
    "    # Fraction of observed entries whose generated value equals the ground truth\n",
    "    adj_tensor = torch.tensor(adj).to(device_id)\n",
    "    acc = (adj_tensor[idx_observed] == ground_truth_adjs[idx_observed]).sum() / idx_observed[0].shape[0]\n",
    "    acc = acc.item()\n",
    "\n",
    "    accs_loss[lr] = acc\n",
    "    scores_loss[lr] = scores\n",
    "    scores_nspdk_loss[lr] = scores_nspdk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Failed runs only have a placeholder accuracy and no metric dict, so plot the values\n",
    "# that have both; otherwise x and y would differ in length.\n",
    "lr_done = [lr for lr in lr_vals if lr in scores_loss]\n",
    "# The uniqueness metric is stored as 'unique@<k>' with k = number of SMILES scored in that\n",
    "# run, which is not always 10000, so it is looked up by prefix.\n",
    "unique_loss = [next(v for k, v in scores_loss[lr].items() if k.startswith('unique@')) for lr in lr_done]\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "color = 'tab:red'\n",
    "ax1.set_xlabel('$\\\\lambda$', fontsize=16)\n",
    "ax1.set_ylabel('Accuracy', fontsize=16, color=color)\n",
    "ax1.semilogx(lr_done, [100*accs_loss[lr] for lr in lr_done], '--', color=color)\n",
    "ax1.tick_params(axis='y', labelsize=12, labelcolor=color)\n",
    "\n",
    "ax2 = ax1.twinx()  # instantiate a second Axes that shares the same x-axis\n",
    "\n",
    "color = 'tab:blue'\n",
    "ax2.set_ylabel('% Unique', fontsize=16, color=color)  # we already handled the x-label with ax1\n",
    "ax2.semilogx(lr_done, [100*u for u in unique_loss], color=color)\n",
    "ax2.tick_params(axis='y', labelsize=12, labelcolor=color)\n",
    "\n",
    "fig.tight_layout()  # otherwise the right y-label is slightly clipped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Markdown-style table of the lambda ablation. Failed runs have no metric dict and are\n",
    "# printed as 'n/a' so that every row keeps one column per lambda.\n",
    "# The uniqueness metric is stored as 'unique@<k>' with k = number of SMILES scored in that\n",
    "# run, which is not always 10000, so it is looked up by prefix.\n",
    "unique_loss = {lr: next(v for k, v in m.items() if k.startswith('unique@')) for lr, m in scores_loss.items()}\n",
    "\n",
    "print(\"$\\\\lambda$ | \" + \"|\".join([str(x) for x in lr_vals]))\n",
    "print(\"--------\" + \" | --\"*len(lr_vals))\n",
    "print(\"Accuracy (%) | \" + \"|\".join([str(round(100*accs_loss[lr],2)) if lr in scores_loss else 'n/a' for lr in lr_vals]))\n",
    "print(\"% Unique | \" + \"|\".join([str(round(100*unique_loss[lr],2)) if lr in scores_loss else 'n/a' for lr in lr_vals]))\n",
    "print(\"Novelty | \" + \"|\".join([str(round(100*scores_loss[lr]['Novelty'],2)) if lr in scores_loss else 'n/a' for lr in lr_vals]))\n",
    "print(\"FCD | \" + \"|\".join([str(round(scores_loss[lr]['FCD/Test'],2)) if lr in scores_loss else 'n/a' for lr in lr_vals]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot only the N values that have a metric dict, so x and y always have equal length.\n",
    "N_done = [n for n in N_vals if n in scores_greedy]\n",
    "# The uniqueness metric is stored as 'unique@<k>' with k = number of SMILES scored in that\n",
    "# run, which is not always 10000, so it is looked up by prefix.\n",
    "unique_greedy = [next(v for k, v in scores_greedy[n].items() if k.startswith('unique@')) for n in N_done]\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "color = 'tab:red'\n",
    "ax1.set_xlabel('$N$', fontsize=16)\n",
    "ax1.set_xticks(N_vals)\n",
    "ax1.set_xticklabels(N_vals)\n",
    "ax1.set_ylabel('Accuracy', fontsize=16, color=color)\n",
    "ax1.plot(N_done, [100*accs_greedy[n] for n in N_done], '--', color=color)\n",
    "ax1.tick_params(axis='y', labelsize=12, labelcolor=color)\n",
    "\n",
    "ax2 = ax1.twinx()  # instantiate a second Axes that shares the same x-axis\n",
    "\n",
    "color = 'tab:blue'\n",
    "ax2.set_ylabel('% Unique', fontsize=16, color=color)  # we already handled the x-label with ax1\n",
    "ax2.plot(N_done, [100*u for u in unique_greedy], color=color)\n",
    "ax2.tick_params(axis='y', labelsize=12, labelcolor=color)\n",
    "\n",
    "fig.tight_layout()  # otherwise the right y-label is slightly clipped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Markdown-style table of the N ablation. Values without a metric dict are printed as\n",
    "# 'n/a' so that every row keeps one column per N.\n",
    "# The uniqueness metric is stored as 'unique@<k>' with k = number of SMILES scored in that\n",
    "# run, which is not always 10000, so it is looked up by prefix.\n",
    "unique_greedy = {n: next(v for k, v in m.items() if k.startswith('unique@')) for n, m in scores_greedy.items()}\n",
    "\n",
    "print(\"$N$ | \" + \"|\".join([str(x) for x in N_vals]))\n",
    "print(\"---\" + \" | --\"*len(N_vals))\n",
    "print(\"Accuracy (%) | \" + \"|\".join([str(round(100*accs_greedy[n],2)) if n in scores_greedy else 'n/a' for n in N_vals]))\n",
    "print(\"% Unique | \" + \"|\".join([str(round(100*unique_greedy[n],2)) if n in scores_greedy else 'n/a' for n in N_vals]))\n",
    "print(\"Novelty | \" + \"|\".join([str(round(100*scores_greedy[n]['Novelty'],2)) if n in scores_greedy else 'n/a' for n in N_vals]))\n",
    "print(\"FCD | \" + \"|\".join([str(round(scores_greedy[n]['FCD/Test'],2)) if n in scores_greedy else 'n/a' for n in N_vals]))"
   ]
  }
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist every sweep. The directory is created first so a fresh checkout does not lose\n",
    "# the results of the long sampling runs to a FileNotFoundError.\n",
    "os.makedirs('results/ablation', exist_ok=True)\n",
    "with open('results/ablation/incomplete_mol.pkl', 'wb') as f:\n",
    "    pickle.dump({\n",
    "        'N_vals': N_vals, 'lr_vals': lr_vals,\n",
    "        'accs_greedy': accs_greedy, 'scores_greedy': scores_greedy,\n",
    "        'accs_loss': accs_loss, 'scores_loss': scores_loss,\n",
    "        'accs_zero': accs_zero, 'scores_zero': scores_zero,\n",
    "        # NSPDK scores were computed in every sweep but previously never saved.\n",
    "        'scores_nspdk_greedy': scores_nspdk_greedy,\n",
    "        'scores_nspdk_loss': scores_nspdk_loss,\n",
    "        'scores_nspdk_zero': scores_nspdk_zero}, f\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "digress-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
