{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import ReverseDiffusionPredictor, EulerMaruyamaPredictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x\n",
    "from tqdm.notebook import trange\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import adjs_to_graphs, init_flags, quantize\n",
    "from utils.plot import save_graph_list, plot_graphs_list\n",
    "\n",
    "from prodigy.project_bisection import drifted_project, get_constraint_config, get_method_config, satisfies\n",
    "import sys\n",
    "sys.path.append('prodigy')\n",
    "from prodigy.evals.evaluate_gdss import evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the project's load_device() helper for the device to use.\n",
    "device = load_device()\n",
    "# NOTE(review): hard-coded override that discards the value above. A later cell turns a\n",
    "# list-valued `device` into f'cuda:{device[0]}', so this pins sampling to 'cuda:1'.\n",
    "device = [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the sampling configuration and the seed handed to get_config().\n",
    "config_file = 'sample_ego_small'\n",
    "seed = 0\n",
    "# `config` supplies the ckpt name and the sample.* / sampler.* settings read in later cells.\n",
    "config = get_config(config_file, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load checkpoint --------\n",
    "# Later cells index ckpt_dict with 'params_x', 'x_state_dict', 'params_adj',\n",
    "# 'adj_state_dict' and, when EMA is enabled, 'ema_x' / 'ema_adj'.\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "# Configuration stored inside the checkpoint; distinct from the sampling `config`.\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed with the value stored in the checkpoint config before loading the data.\n",
    "load_seed(configt.seed)\n",
    "# With get_graph_list=True the result is unpacked as (train graphs, test graphs).\n",
    "train_graph_list, test_graph_list = load_data(configt, get_graph_list=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the log locations for a non-training run; the third returned value is unused.\n",
    "log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "\n",
    "# The logger appends to <log_dir>/<log_name>.log.\n",
    "log_path = os.path.join(log_dir, f'{log_name}.log')\n",
    "logger = Logger(str(log_path), mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override the `init` entry of the checkpoint's data config. This runs after load_data()\n",
    "# above, so it only affects the later init_flags(train_graph_list, configt) calls and\n",
    "# anything else that receives `configt` from here on.\n",
    "configt.data.init = \"deg\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of simple cycles in each held-out test graph.\n",
    "cycle_counts_test = [sum(1 for _ in nx.simple_cycles(graph)) for graph in test_graph_list]\n",
    "n_cycles_test = np.array(cycle_counts_test)\n",
    "# Displayed as (mean, std).\n",
    "n_cycles_test.mean(), n_cycles_test.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two held-out graphs side by side, each titled with its simple-cycle count.\n",
    "fig, ax = plt.subplots(1, 2, figsize=(10, 5))\n",
    "for panel, graph_idx in zip(ax, (3, 7)):\n",
    "    nx.draw(test_graph_list[graph_idx], ax=panel)\n",
    "\n",
    "ax[0].set_title(f\"n_cycles: {n_cycles_test[3]}\")\n",
    "ax[1].set_title(f\"n_cycles: {n_cycles_test[7]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constraint whose guidance configs are loaded below (alternatives: nedges, ntriangles, degree).\n",
    "constraint_name = 'ncycles'\n",
    "\n",
    "\n",
    "def is_star(g):\n",
    "    \"\"\"Return True when `g` is isomorphic to the star graph with the same number of nodes.\"\"\"\n",
    "    reference = nx.star_graph(g.number_of_nodes() - 1)\n",
    "    return nx.is_isomorphic(g, reference)\n",
    "\n",
    "\n",
    "def is_egonet(g):\n",
    "    \"\"\"Return True when at least one node of `g` has degree equal to the node count minus one.\"\"\"\n",
    "    for _, degree in g.degree():\n",
    "        if degree == g.number_of_nodes() - 1:\n",
    "            return True\n",
    "    return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply the ego-net check to the two test graphs drawn above (index 7 first, then 3).\n",
    "is_egonet(test_graph_list[7]), is_egonet(test_graph_list[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the header entries only when check_log() returns a falsy value for this log name.\n",
    "if not check_log(log_folder_name, log_name):\n",
    "    logger.log(f'{log_name}')\n",
    "    start_log(logger, configt)\n",
    "    train_log(logger, configt)\n",
    "# Always record the sampling settings; this call takes `config`, not `configt`.\n",
    "sample_log(logger, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load models --------\n",
    "# Rebuild the x and adj models from the params and state dicts stored in the checkpoint.\n",
    "model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# When the sampling config enables EMA, load the EMA objects stored in the checkpoint\n",
    "# (built with the training-time configt.train.ema value) and copy them onto the models.\n",
    "if config.sample.use_ema:\n",
    "    ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "    ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "    \n",
    "    # copy_to() receives each model's parameters as the destination.\n",
    "    ema_x.copy_to(model_x.parameters())\n",
    "    ema_adj.copy_to(model_adj.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed with the sampling seed; the checkpoint seed was applied before data loading.\n",
    "print(f'GEN SEED: {config.sample.seed}')\n",
    "load_seed(config.sample.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SDE objects for node features (x) and adjacency (adj), built from the checkpoint config.\n",
    "sde_x = load_sde(configt.sde.x)\n",
    "sde_adj = load_sde(configt.sde.adj)\n",
    "max_node_num = configt.data.max_node_num\n",
    "\n",
    "# A list-valued `device` is reduced to the CUDA device string of its first entry;\n",
    "# any other value is used unchanged.\n",
    "if isinstance(device, list):\n",
    "    device_id = f'cuda:{device[0]}'\n",
    "else:\n",
    "    device_id = device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed as `continuous=` to get_score_fn() when the score functions are built.\n",
    "continuous = True\n",
    "\n",
    "# Names compared against 'Reverse' / 'Langevin' when the solver classes are chosen.\n",
    "predictor = config.sampler.predictor\n",
    "corrector = config.sampler.corrector\n",
    "\n",
    "# Forwarded to the corrector constructors.\n",
    "snr=config.sampler.snr\n",
    "scale_eps=config.sampler.scale_eps\n",
    "n_steps=config.sampler.n_steps\n",
    "\n",
    "# probability_flow is forwarded to the predictor constructors; eps is the final time of\n",
    "# the time grid built inside sample().\n",
    "probability_flow = config.sample.probability_flow\n",
    "eps = config.sample.eps\n",
    "# NOTE(review): `denoise` is not read by any other cell in this notebook.\n",
    "denoise = config.sample.noise_removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QM9 uses a fixed batch of 10000 graphs; every other dataset uses the configured batch\n",
    "# size (original note: 'ZINC250k' gives OOM).\n",
    "n_graphs = 10000 if configt.data.data in ['QM9'] else configt.data.batch_size\n",
    "\n",
    "# (batch, nodes, features) for x and (batch, nodes, nodes) for adj.\n",
    "shape_x = (n_graphs, max_node_num, configt.data.max_feat_num)\n",
    "shape_adj = (n_graphs, max_node_num, max_node_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node flags built once from the training graphs and passed to every sample() call below.\n",
    "init_flags_iter = init_flags(train_graph_list, configt).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One prior draw for x and one (via prior_sampling_sym) for adj. The same pair is passed\n",
    "# to every sampling run below; sample() clones the tensors before using them.\n",
    "init_x = sde_x.prior_sampling(shape_x).to(device_id)\n",
    "init_adj = sde_adj.prior_sampling_sym(shape_adj).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None):\n",
    "    \"\"\"Run the corrector/predictor reverse-diffusion loop and return the generated graphs.\n",
    "\n",
    "    Args:\n",
    "        predictor_x, corrector_x: solver objects for the node features; their\n",
    "            `update_fn(x, adj, flags, t)` is called once per step.\n",
    "        predictor_adj, corrector_adj: the same for the adjacency. `predictor_adj.sde`\n",
    "            supplies the number of steps `N` and the start time `T`.\n",
    "        init_x, init_adj: optional starting tensors; they are cloned, not modified.\n",
    "            When omitted, prior samples of shape `shape_x` / `shape_adj` are drawn.\n",
    "        flags: optional node flags; built from `train_graph_list` when omitted.\n",
    "\n",
    "    Returns:\n",
    "        The list returned by `adjs_to_graphs(quantize(adj), True)`.\n",
    "\n",
    "    Reads the notebook globals `shape_x`, `shape_adj`, `device_id`, `eps`,\n",
    "    `train_graph_list` and `configt`.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        if init_x is not None:\n",
    "            x = init_x.clone()\n",
    "        else:\n",
    "            x = predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        if init_adj is not None:\n",
    "            adj = init_adj.clone()\n",
    "        else:\n",
    "            adj = predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "\n",
    "        if flags is None:\n",
    "            flags = init_flags(train_graph_list, configt).to(device_id)\n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "        diff_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, diff_steps, device=device_id)\n",
    "        # Size the per-sample time vector from the tensor actually being sampled, so a\n",
    "        # caller-supplied init_adj whose batch differs from the global shape_adj works.\n",
    "        batch_size = adj.shape[0]\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):\n",
    "            t = timesteps[i]\n",
    "            vec_t = torch.ones(batch_size, device=t.device) * t\n",
    "\n",
    "            # The adj update of each stage sees the features from before that stage's x update.\n",
    "            _x = x\n",
    "            x, x_mean = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = corrector_adj.update_fn(_x, adj, flags, vec_t)\n",
    "\n",
    "            if torch.any(torch.isnan(adj)):\n",
    "                # Stop as before, but report it instead of exiting silently.\n",
    "                print(f'[Sampling] NaN in adj at step {i} of {diff_steps}; stopping early.')\n",
    "                break\n",
    "\n",
    "            _x = x\n",
    "            x, x_mean = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = predictor_adj.update_fn(_x, adj, flags, vec_t)\n",
    "        print(' ')\n",
    "    # NOTE(review): the final `adj` is quantized; `adj_mean` and the notebook-level\n",
    "    # `denoise` setting are not used here - confirm that this is intended.\n",
    "    samples_int = quantize(adj)\n",
    "    gen_graph_list = adjs_to_graphs(samples_int, True)\n",
    "    return gen_graph_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guidance settings of the greedy method for the selected constraint.\n",
    "guidance_config = load_yaml_config(f'config_guidance/constrained/{constraint_name}/greedy.yaml')\n",
    "\n",
    "# Shared objective plus the entries stored under the lower-cased dataset name; the\n",
    "# dataset entries are unpacked last, so they win on key clashes.\n",
    "guidance_obj = guidance_config['obj']\n",
    "dataset_guidance = guidance_config[configt.data.data.lower()]\n",
    "guidance_args = edict({'method': 'greedy', 'obj': guidance_obj, **dataset_guidance})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions for both modalities, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Any predictor name other than 'Reverse' selects EulerMaruyamaPredictor; any corrector\n",
    "# name other than 'Langevin' selects NoneCorrector.\n",
    "if predictor == 'Reverse':\n",
    "    predictor_fn = ReverseDiffusionPredictor\n",
    "else:\n",
    "    predictor_fn = EulerMaruyamaPredictor\n",
    "if corrector == 'Langevin':\n",
    "    corrector_fn = LangevinCorrector\n",
    "else:\n",
    "    corrector_fn = NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "# Only the adjacency predictor receives the (greedy) guidance arguments.\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "gen_graph_list_greedy = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple-cycle count per greedy sample, and whether each sample passes the ego-net check.\n",
    "cycle_counts_greedy = [sum(1 for _ in nx.simple_cycles(graph)) for graph in gen_graph_list_greedy]\n",
    "n_cycles_greedy = np.array(cycle_counts_greedy)\n",
    "valid_greedy = np.array(list(map(is_egonet, gen_graph_list_greedy)))\n",
    "# Displayed as (mean cycle count, fraction of valid graphs).\n",
    "n_cycles_greedy.mean(), valid_greedy.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-order optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guidance settings of the zero-order method for the selected constraint.\n",
    "guidance_config = load_yaml_config(f'config_guidance/constrained/{constraint_name}/zero.yaml')\n",
    "\n",
    "# Shared objective plus the entries stored under the lower-cased dataset name; the\n",
    "# dataset entries are unpacked last, so they win on key clashes.\n",
    "guidance_obj = guidance_config['obj']\n",
    "dataset_guidance = guidance_config[configt.data.data.lower()]\n",
    "guidance_args = edict({'method': 'zero', 'obj': guidance_obj, **dataset_guidance})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Score functions for both modalities, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Any predictor name other than 'Reverse' selects EulerMaruyamaPredictor; any corrector\n",
    "# name other than 'Langevin' selects NoneCorrector.\n",
    "if predictor == 'Reverse':\n",
    "    predictor_fn = ReverseDiffusionPredictor\n",
    "else:\n",
    "    predictor_fn = EulerMaruyamaPredictor\n",
    "if corrector == 'Langevin':\n",
    "    corrector_fn = LangevinCorrector\n",
    "else:\n",
    "    corrector_fn = NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "# Only the adjacency predictor receives the (zero-order) guidance arguments.\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "gen_graph_list_zero = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple-cycle count per zero-order sample, and whether each sample passes the ego-net check.\n",
    "cycle_counts_zero = [sum(1 for _ in nx.simple_cycles(graph)) for graph in gen_graph_list_zero]\n",
    "n_cycles_zero = np.array(cycle_counts_zero)\n",
    "valid_zero = np.array(list(map(is_egonet, gen_graph_list_zero)))\n",
    "# Displayed as (mean cycle count, fraction of valid graphs).\n",
    "n_cycles_zero.mean(), valid_zero.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unconstrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions for both modalities, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Any predictor name other than 'Reverse' selects EulerMaruyamaPredictor; any corrector\n",
    "# name other than 'Langevin' selects NoneCorrector.\n",
    "if predictor == 'Reverse':\n",
    "    predictor_fn = ReverseDiffusionPredictor\n",
    "else:\n",
    "    predictor_fn = EulerMaruyamaPredictor\n",
    "if corrector == 'Langevin':\n",
    "    corrector_fn = LangevinCorrector\n",
    "else:\n",
    "    corrector_fn = NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "# Baseline run: no guidance arguments are passed to the adjacency predictor.\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "gen_graph_list_uncons = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simple-cycle count per unconstrained sample, and whether each sample passes the ego-net check.\n",
    "cycle_counts_uncons = [sum(1 for _ in nx.simple_cycles(graph)) for graph in gen_graph_list_uncons]\n",
    "n_cycles_uncons = np.array(cycle_counts_uncons)\n",
    "valid_uncons = np.array(list(map(is_egonet, gen_graph_list_uncons)))\n",
    "# Displayed as (mean cycle count, fraction of valid graphs).\n",
    "n_cycles_uncons.mean(), valid_uncons.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per source: held-out test graphs, then greedy, zero-order and unconstrained samples.\n",
    "graph_rows = [\n",
    "    ('Test', test_graph_list),\n",
    "    ('Greedy', gen_graph_list_greedy),\n",
    "    ('Zero-order', gen_graph_list_zero),\n",
    "    ('Unconstrained', gen_graph_list_uncons),\n",
    "]\n",
    "# The same random index is used in every row, so it has to be valid for the shortest\n",
    "# list; drawing it from len(test_graph_list) alone can run past the generated batches.\n",
    "n_common = min(len(graphs) for _, graphs in graph_rows)\n",
    "\n",
    "f, ax = plt.subplots(4, 5, figsize=(20,12))\n",
    "\n",
    "for i in range(5):\n",
    "    j = np.random.randint(n_common)\n",
    "    for row, (row_name, graphs) in enumerate(graph_rows):\n",
    "        nx.draw(graphs[j], ax=ax[row, i], node_size=10)\n",
    "        if i == 0:\n",
    "            # Label each row once, on its first panel.\n",
    "            ax[row, i].set_title(row_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the cycle-count histograms of the three sampling runs.\n",
    "for counts, label in (\n",
    "    (n_cycles_greedy, 'Greedy'),\n",
    "    (n_cycles_zero, 'Loss'),\n",
    "    (n_cycles_uncons, 'Uncons'),\n",
    "):\n",
    "    plt.hist(counts, alpha=0.5, bins=20, label=label)\n",
    "plt.legend()\n",
    "plt.title(\"Distribution of the number of cycles\", fontsize=18)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pandas is already imported as `pd` in the first cell, so it is not re-imported here.\n",
    "\n",
    "def _mean_pm_std(values):\n",
    "    \"\"\"Format `values` as LaTeX 'mean plus-minus std', both rounded to two decimals.\"\"\"\n",
    "    return f\"{values.mean().round(2)} $\\\\pm$ {values.std().round(2)}\"\n",
    "\n",
    "\n",
    "def _percent(flags):\n",
    "    \"\"\"Return the share of True entries in `flags` as a whole-number percentage.\"\"\"\n",
    "    # Scale first, then round: rounding the fraction and multiplying by 100 afterwards\n",
    "    # can leave float artifacts such as 56.99999999999999.\n",
    "    return (100 * flags.mean()).round()\n",
    "\n",
    "\n",
    "# Summary table: cycle statistics and share of samples passing the ego-net check.\n",
    "df_results = pd.DataFrame({\n",
    "    'Method': ['GGDiff-C', 'GGDiff-Z', 'Uncons.'],\n",
    "    '#Cycles': [_mean_pm_std(n_cycles_greedy), _mean_pm_std(n_cycles_zero), _mean_pm_std(n_cycles_uncons)],\n",
    "    'Valid %': [_percent(valid_greedy), _percent(valid_zero), _percent(valid_uncons)]\n",
    "})\n",
    "df_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the summary table as LaTeX source, without the index column.\n",
    "print(df_results.to_latex(index=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Drawing options shared by every subplot.\n",
    "options = {\n",
    "    'node_size': 2,\n",
    "    'edge_color' : 'black',\n",
    "    'linewidths': 1,\n",
    "    'width': 0.5\n",
    "}\n",
    "\n",
    "def plot_graphs_list(graphs, max_num=16, N=0):\n",
    "    \"\"\"Draw one page of `graphs` on a square grid of subplots in a new figure.\n",
    "\n",
    "    This definition replaces the `plot_graphs_list` imported from utils.plot in the\n",
    "    first cell; it takes only the three parameters below.\n",
    "\n",
    "    Args:\n",
    "        graphs: sequence of networkx graphs, or of objects exposing the graph as `.g`.\n",
    "        max_num: maximum number of graphs per page (capped at len(graphs)).\n",
    "        N: page index; page N starts at position max_num * N.\n",
    "    \"\"\"\n",
    "    batch_size = len(graphs)\n",
    "    max_num = min(batch_size, max_num)\n",
    "    img_c = int(math.ceil(np.sqrt(max_num)))\n",
    "    plt.figure()\n",
    "\n",
    "    first = max_num * N\n",
    "    # Clamp the page to the end of the list; an unclamped non-zero N indexes past it.\n",
    "    last = min(first + max_num, batch_size)\n",
    "    for i, idx in enumerate(range(first, last)):\n",
    "        item = graphs[idx]\n",
    "        G = item.copy() if isinstance(item, nx.Graph) else item.g.copy()\n",
    "        assert isinstance(G, nx.Graph)\n",
    "        # Isolated nodes are removed from the copy before drawing.\n",
    "        G.remove_nodes_from(list(nx.isolates(G)))\n",
    "\n",
    "        plt.subplot(img_c, img_c, i + 1)\n",
    "        pos = nx.spring_layout(G)\n",
    "        nx.draw(G, pos, with_labels=False, **options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The notebook-local plot_graphs_list takes (graphs, max_num, N) and has no title or\n",
    "# save-directory parameter, so only the graphs and max_num are passed.\n",
    "plot_graphs_list(gen_graph_list_greedy, max_num=25)\n",
    "# Create the output directory if needed so savefig does not fail on a fresh checkout.\n",
    "os.makedirs(\"results/cycles\", exist_ok=True)\n",
    "plt.savefig(\"results/cycles/greedy_samples.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The notebook-local plot_graphs_list takes (graphs, max_num, N) and has no title or\n",
    "# save-directory parameter, so only the graphs and max_num are passed.\n",
    "plot_graphs_list(gen_graph_list_zero, max_num=25)\n",
    "# Create the output directory if needed so savefig does not fail on a fresh checkout.\n",
    "os.makedirs(\"results/cycles\", exist_ok=True)\n",
    "plt.savefig(\"results/cycles/zero_samples.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The notebook-local plot_graphs_list takes (graphs, max_num, N) and has no title or\n",
    "# save-directory parameter, so only the graphs and max_num are passed.\n",
    "plot_graphs_list(gen_graph_list_uncons, max_num=25)\n",
    "# Create the output directory if needed so savefig does not fail on a fresh checkout.\n",
    "os.makedirs(\"results/cycles\", exist_ok=True)\n",
    "plt.savefig(\"results/cycles/uncons_samples.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "digress-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
