{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import ReverseDiffusionPredictor, EulerMaruyamaPredictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x\n",
    "from tqdm.notebook import trange\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import adjs_to_graphs, init_flags, quantize\n",
    "\n",
    "from prodigy.project_bisection import drifted_project, get_constraint_config, get_method_config, satisfies\n",
    "import sys\n",
    "sys.path.append('prodigy')\n",
    "from prodigy.evals.evaluate_gdss import evaluate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling runs on a single device. Instead of overwriting load_device() with a\n",
    "# hard-coded [0] (which presumably fails on machines without CUDA), keep only the\n",
    "# first GPU id when a list is returned and pass any non-list value through unchanged.\n",
    "device = load_device()\n",
    "if isinstance(device, list):\n",
    "    device = device[:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling config: config.ckpt names the checkpoint; config.sample / config.sampler hold the sampling settings used below.\n",
    "config_file = 'sample_ego_small'\n",
    "seed = 0\n",
    "config = get_config(config_file, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load checkpoint --------\n",
    "# configt is the config stored inside the checkpoint; its data / sde / train sections are read below.\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed with the checkpoint's seed, then load the train/test splits as lists of graphs.\n",
    "load_seed(configt.seed)\n",
    "train_graph_list, test_graph_list = load_data(configt, get_graph_list=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling log: <log_dir>/<config.ckpt>-sample-guidance.log, opened in append mode.\n",
    "log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "logger = Logger(str(os.path.join(log_dir, f'{log_name}.log')), mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Override the feature-initialisation mode stored in the checkpoint config.\n",
    "# NOTE(review): presumably 'deg' selects degree-based node features - confirm in utils.\n",
    "configt.data.init = \"deg\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph statistics of the test set.\n",
    "n_edges_test = np.array([G.number_of_edges() for G in test_graph_list])\n",
    "# NOTE(review): sum(nx.triangles(G).values()) counts every triangle once per vertex (3x);\n",
    "# confirm this matches the units of prodigy's ntriangles constraint.\n",
    "n_triangles_test = np.array([sum(nx.triangles(G).values()) for G in test_graph_list])\n",
    "max_degrees_test = np.array([max(deg for _, deg in G.degree()) for G in test_graph_list])\n",
    "\n",
    "# Constraint thresholds: the 10th percentile of each test statistic.\n",
    "max_edges, max_triangles, max_degree = (\n",
    "    np.percentile(stat, 10) for stat in (n_edges_test, n_triangles_test, max_degrees_test)\n",
    ")\n",
    "constraint_param_map = {'nedges': max_edges, 'ntriangles': max_triangles, 'degree': max_degree}\n",
    "max_edges, max_triangles, max_degree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constraint to enforce; its constraint/method configs are read from prodigy/configs/<name>/.\n",
    "constraint_name = 'nedges' # nedges, ntriangles, degree\n",
    "constraint_config = get_constraint_config(f'prodigy/configs/{constraint_name}/constraint.yaml')\n",
    "method_config = get_method_config(f'prodigy/configs/{constraint_name}/method.yaml')\n",
    "# Replace the last constraint parameter with the test-set threshold for this constraint.\n",
    "constraint_config.params[-1] = constraint_param_map[constraint_name]\n",
    "constraint_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph checks mirroring each constraint, using the thresholds computed from the test set.\n",
    "satisfies_nedges = lambda g: g.number_of_edges() <= max_edges\n",
    "satisfies_degree = lambda g: max([x[1] for x in g.degree()]) <= max_degree\n",
    "satisfies_ntriangles = lambda g: sum(list(nx.triangles(g).values())) <= max_triangles\n",
    "satisfies_force_stars = lambda g: g.number_of_edges() == (g.number_of_nodes() - 1)\n",
    "\n",
    "# Select the checker with a dictionary lookup instead of eval() on a string built from\n",
    "# constraint_name; an unknown name raises a KeyError naming the missing key.\n",
    "satisfies_by_name = {\n",
    "    'nedges': satisfies_nedges,\n",
    "    'degree': satisfies_degree,\n",
    "    'ntriangles': satisfies_ntriangles,\n",
    "    'force_stars': satisfies_force_stars,\n",
    "}\n",
    "satisfies_fn = satisfies_by_name[constraint_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the run header (start_log / train_log) only if check_log reports this log_name as new;\n",
    "# the sampling settings are appended on every run.\n",
    "if not check_log(log_folder_name, log_name):\n",
    "    logger.log(f'{log_name}')\n",
    "    start_log(logger, configt)\n",
    "    train_log(logger, configt)\n",
    "sample_log(logger, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load models --------\n",
    "# Node-feature (x) and adjacency (adj) models, rebuilt from the checkpoint's params and state dicts.\n",
    "model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally overwrite the model weights with the EMA weights saved in the checkpoint.\n",
    "if config.sample.use_ema:\n",
    "    ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "    ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "    \n",
    "    ema_x.copy_to(model_x.parameters())\n",
    "    ema_adj.copy_to(model_adj.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reseed with the sampling seed right before generation.\n",
    "print(f'GEN SEED: {config.sample.seed}')\n",
    "load_seed(config.sample.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SDEs for node features (x) and adjacency (adj), taken from the checkpoint config.\n",
    "sde_x = load_sde(configt.sde.x)\n",
    "sde_adj = load_sde(configt.sde.adj)\n",
    "max_node_num  = configt.data.max_node_num\n",
    "\n",
    "# Torch device for new tensors: 'cuda:<first id>' when device is a list, otherwise device as-is.\n",
    "device_id = f'cuda:{device[0]}' if isinstance(device, list) else device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampler settings shared by every method below; continuous is forwarded to get_score_fn.\n",
    "continuous = True\n",
    "\n",
    "# Names that select the predictor ('Reverse') and corrector ('Langevin') classes.\n",
    "predictor = config.sampler.predictor\n",
    "corrector = config.sampler.corrector\n",
    "\n",
    "# Positional arguments passed to every corrector.\n",
    "snr=config.sampler.snr\n",
    "scale_eps=config.sampler.scale_eps\n",
    "n_steps=config.sampler.n_steps\n",
    "\n",
    "# eps is the last point of the reverse-time grid (torch.linspace(T, eps, N)).\n",
    "probability_flow = config.sample.probability_flow\n",
    "eps = config.sample.eps\n",
    "# NOTE(review): denoise is not read by any cell in this notebook, so noise removal is not applied.\n",
    "denoise = config.sample.noise_removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QM9 is generated in one batch of 10000 ('ZINC250k' gives OOM); other datasets use the training batch size.\n",
    "n_samples = 10000 if configt.data.data in ['QM9'] else configt.data.batch_size\n",
    "shape_x = (n_samples, max_node_num, configt.data.max_feat_num)\n",
    "shape_adj = (n_samples, max_node_num, max_node_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node flags (masks) drawn once from the training graphs and reused by every method below.\n",
    "# NOTE(review): init_flags gets no batch size here; for the QM9 branch (10000 samples)\n",
    "# confirm the flags batch matches shape_adj[0].\n",
    "init_flags_iter = init_flags(train_graph_list, configt).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared initial noise: every method starts from a clone of the same init_x / init_adj.\n",
    "init_x = sde_x.prior_sampling(shape_x).to(device_id)\n",
    "init_adj = sde_adj.prior_sampling_sym(shape_adj).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None):\n",
    "    \"\"\"Run predictor-corrector reverse diffusion and return the generated graphs.\n",
    "\n",
    "    Every step applies the corrector and then the predictor to both channels; in each\n",
    "    sub-step the adjacency update receives the node features from before that\n",
    "    sub-step's x update.\n",
    "\n",
    "    Args:\n",
    "        predictor_x, corrector_x: predictor / corrector for the node features x.\n",
    "        predictor_adj, corrector_adj: predictor / corrector for the adjacency; the\n",
    "            adjacency SDE also fixes the number of steps and the time grid.\n",
    "        init_x, init_adj: optional initial noise. They are cloned, never modified;\n",
    "            when None, fresh noise is drawn from the SDE priors.\n",
    "        flags: optional node mask; drawn with init_flags from train_graph_list when None.\n",
    "\n",
    "    Returns:\n",
    "        The graphs built by adjs_to_graphs from the quantized final adjacency.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        if init_x is not None:\n",
    "            x = init_x.clone()\n",
    "        else:\n",
    "            x = predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        if init_adj is not None:\n",
    "            adj = init_adj.clone()\n",
    "        else:\n",
    "            adj = predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "\n",
    "        if flags is None:\n",
    "            flags = init_flags(train_graph_list, configt).to(device_id)\n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "        diff_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, diff_steps, device=device_id)\n",
    "        # Take the batch size from the actual tensor rather than the global shape_adj, so a\n",
    "        # caller-supplied init_adj with a different batch size gets a matching vec_t.\n",
    "        batch_size = adj.shape[0]\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for i in trange(diff_steps, desc='[Sampling]', position=1, leave=False):\n",
    "            t = timesteps[i]\n",
    "            vec_t = torch.ones(batch_size, device=t.device) * t\n",
    "\n",
    "            _x = x\n",
    "            x, x_mean = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = corrector_adj.update_fn(_x, adj, flags, vec_t)\n",
    "\n",
    "            _x = x\n",
    "            x, x_mean = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = predictor_adj.update_fn(_x, adj, flags, vec_t)\n",
    "        print(' ')\n",
    "    # NOTE(review): config.sample.noise_removal (read into `denoise`) is not applied here;\n",
    "    # the last noisy adj is quantized rather than adj_mean. Confirm this is intended.\n",
    "    samples_int = quantize(adj)\n",
    "    gen_graph_list = adjs_to_graphs(samples_int, True)\n",
    "    return gen_graph_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedy guidance settings: the objective plus this dataset's section of greedy.yaml.\n",
    "guidance_config = load_yaml_config(f'config_guidance/constrained/{constraint_name}/greedy.yaml')\n",
    "guidance_args = edict({'method': 'greedy', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Greedy guidance: guidance_args is handed to the adjacency predictor only.\n",
    "score_kwargs = dict(train=False, continuous=continuous)\n",
    "score_fn_x = get_score_fn(sde_x, model_x, **score_kwargs)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, **score_kwargs)\n",
    "\n",
    "predictor_fn = EulerMaruyamaPredictor if predictor != 'Reverse' else ReverseDiffusionPredictor\n",
    "corrector_fn = NoneCorrector if corrector != 'Langevin' else LangevinCorrector\n",
    "corrector_args = (snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, *corrector_args)\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, *corrector_args)\n",
    "\n",
    "gen_graph_list_greedy = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj,\n",
    "                               init_x=init_x, init_adj=init_adj, flags=init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph statistics of the greedy samples; the cell shows (mean, std) of the edge count.\n",
    "n_nodes_greedy = np.array([G.number_of_nodes() for G in gen_graph_list_greedy])\n",
    "n_edges_greedy = np.array([G.number_of_edges() for G in gen_graph_list_greedy])\n",
    "n_triangles_greedy = np.array([sum(nx.triangles(G).values()) for G in gen_graph_list_greedy])\n",
    "max_degree_greedy = np.array([max(deg for _, deg in G.degree()) for G in gen_graph_list_greedy])\n",
    "(n_edges_greedy.mean(), n_edges_greedy.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss guidance settings: the objective plus this dataset's section of loss.yaml.\n",
    "guidance_config = load_yaml_config(f'config_guidance/constrained/{constraint_name}/loss.yaml')\n",
    "guidance_args = edict({'method': 'loss', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor\n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "# Guidance (method 'loss') is passed to the adjacency predictor only.\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "gen_graph_list_loss = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph statistics of the loss-guided samples; the cell shows (mean, std) of the edge count.\n",
    "n_nodes_loss = np.array([G.number_of_nodes() for G in gen_graph_list_loss])\n",
    "n_edges_loss = np.array([G.number_of_edges() for G in gen_graph_list_loss])\n",
    "n_triangles_loss = np.array([sum(nx.triangles(G).values()) for G in gen_graph_list_loss])\n",
    "max_degree_loss = np.array([max(deg for _, deg in G.degree()) for G in gen_graph_list_loss])\n",
    "(n_edges_loss.mean(), n_edges_loss.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-order optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-order guidance settings: the objective plus this dataset's section of zero.yaml.\n",
    "guidance_config = load_yaml_config(f'config_guidance/constrained/{constraint_name}/zero.yaml')\n",
    "guidance_args = edict({'method': 'zero', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-order guidance: only the adjacency predictor receives guidance_args.\n",
    "score_kwargs = dict(train=False, continuous=continuous)\n",
    "score_fn_x = get_score_fn(sde_x, model_x, **score_kwargs)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, **score_kwargs)\n",
    "\n",
    "predictor_fn = EulerMaruyamaPredictor if predictor != 'Reverse' else ReverseDiffusionPredictor\n",
    "corrector_fn = NoneCorrector if corrector != 'Langevin' else LangevinCorrector\n",
    "corrector_args = (snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, *corrector_args)\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, *corrector_args)\n",
    "\n",
    "gen_graph_list_zero = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj,\n",
    "                             init_x=init_x, init_adj=init_adj, flags=init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph statistics of the zero-order samples; the cell shows (mean, std) of the edge count.\n",
    "n_nodes_zero = np.array([G.number_of_nodes() for G in gen_graph_list_zero])\n",
    "n_edges_zero = np.array([G.number_of_edges() for G in gen_graph_list_zero])\n",
    "n_triangles_zero = np.array([sum(nx.triangles(G).values()) for G in gen_graph_list_zero])\n",
    "max_degree_zero = np.array([max(deg for _, deg in G.degree()) for G in gen_graph_list_zero])\n",
    "(n_edges_zero.mean(), n_edges_zero.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prodigy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PRODIGY: plain (unguided) predictor/corrector objects; the constraint is imposed by\n",
    "# projecting the state with drifted_project after every reverse-diffusion step.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "# Same loop as sample(), inlined so that a projection can run after each step.\n",
    "with torch.no_grad():\n",
    "    # -------- Initial sample --------\n",
    "    x = init_x.clone()\n",
    "    adj = init_adj.clone() \n",
    "     \n",
    "    flags = init_flags_iter.clone()\n",
    "    x = mask_x(x, flags)\n",
    "    adj = mask_adjs(adj, flags)\n",
    "    diff_steps = sde_adj.N\n",
    "    timesteps = torch.linspace(sde_adj.T, eps, diff_steps, device=device_id)\n",
    "\n",
    "    # -------- Reverse diffusion process --------\n",
    "    for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):\n",
    "        t = timesteps[i]\n",
    "        vec_t = torch.ones(shape_adj[0], device=t.device) * t\n",
    "\n",
    "        _x = x\n",
    "        x, x_mean = corrector_obj_x.update_fn(x, adj, flags, vec_t)\n",
    "        adj, adj_mean = corrector_obj_adj.update_fn(_x, adj, flags, vec_t)\n",
    "\n",
    "        _x = x\n",
    "        x, x_mean = predictor_obj_x.update_fn(x, adj, flags, vec_t)\n",
    "        adj, adj_mean = predictor_obj_adj.update_fn(_x, adj, flags, vec_t)\n",
    "\n",
    "        # Project the current state, given the step index i and diff_steps plus the\n",
    "        # constraint and method configs loaded above.\n",
    "        x, adj = drifted_project(x, adj, i=i, diff_steps=diff_steps, constraint_config=constraint_config, method_config=method_config)\n",
    "        # NOTE(review): the projected x_mean / adj_mean are never read after the loop (only adj is\n",
    "        # quantized). Confirm whether they were meant for noise removal (config.sample.noise_removal).\n",
    "        x_mean, adj_mean = drifted_project(x_mean, adj_mean, i=i, diff_steps=diff_steps, constraint_config=constraint_config, method_config=method_config)\n",
    "    print(' ')\n",
    "samples_int_prodigy = quantize(adj)\n",
    "gen_graph_list_prodigy = adjs_to_graphs(samples_int_prodigy, True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph statistics of the PRODIGY samples; the cell shows (mean, std) of the edge count.\n",
    "n_nodes_prodigy = np.array([G.number_of_nodes() for G in gen_graph_list_prodigy])\n",
    "n_edges_prodigy = np.array([G.number_of_edges() for G in gen_graph_list_prodigy])\n",
    "n_triangles_prodigy = np.array([sum(nx.triangles(G).values()) for G in gen_graph_list_prodigy])\n",
    "max_degree_prodigy = np.array([max(deg for _, deg in G.degree()) for G in gen_graph_list_prodigy])\n",
    "(n_edges_prodigy.mean(), n_edges_prodigy.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unconstrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unconstrained baseline: no guidance_args and no projection.\n",
    "score_kwargs = dict(train=False, continuous=continuous)\n",
    "score_fn_x = get_score_fn(sde_x, model_x, **score_kwargs)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, **score_kwargs)\n",
    "\n",
    "predictor_fn = EulerMaruyamaPredictor if predictor != 'Reverse' else ReverseDiffusionPredictor\n",
    "corrector_fn = NoneCorrector if corrector != 'Langevin' else LangevinCorrector\n",
    "corrector_args = (snr, scale_eps, n_steps)\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, *corrector_args)\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, *corrector_args)\n",
    "\n",
    "gen_graph_list_uncons = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj,\n",
    "                               init_x=init_x, init_adj=init_adj, flags=init_flags_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph statistics of the unconstrained samples; the cell shows (mean, std) of the edge count.\n",
    "n_nodes_uncons = np.array([G.number_of_nodes() for G in gen_graph_list_uncons])\n",
    "n_edges_uncons = np.array([G.number_of_edges() for G in gen_graph_list_uncons])\n",
    "n_triangles_uncons = np.array([sum(nx.triangles(G).values()) for G in gen_graph_list_uncons])\n",
    "max_degree_uncons = np.array([max(deg for _, deg in G.degree()) for G in gen_graph_list_uncons])\n",
    "(n_edges_uncons.mean(), n_edges_uncons.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One column per random draw. Row 0 is a random test graph; the other rows show the\n",
    "# generated graph with the same index j for every method (all methods share init_x,\n",
    "# init_adj and init_flags_iter). The test graph gets its own index because the test\n",
    "# set is not aligned with the generated batch and may be shorter than it.\n",
    "rows_to_draw = [('Test', test_graph_list), ('Greedy', gen_graph_list_greedy),\n",
    "                ('Loss', gen_graph_list_loss), ('Zero', gen_graph_list_zero),\n",
    "                ('Prodigy', gen_graph_list_prodigy), ('Uncons', gen_graph_list_uncons)]\n",
    "n_cols = 5\n",
    "f, ax = plt.subplots(len(rows_to_draw), n_cols, figsize=(20, 14))\n",
    "\n",
    "for i in range(n_cols):\n",
    "    j = np.random.randint(len(gen_graph_list_greedy))\n",
    "    j_test = np.random.randint(len(test_graph_list))\n",
    "    for row, (name, graph_list) in enumerate(rows_to_draw):\n",
    "        nx.draw(graph_list[j_test if row == 0 else j], ax=ax[row, i], node_size=10)\n",
    "\n",
    "for row, (name, _) in enumerate(rows_to_draw):\n",
    "    ax[row, 0].set_title(name, loc='left')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge-count distribution per method (Zero is appended last so earlier colours are unchanged).\n",
    "for label, edges in [('Greedy', n_edges_greedy), ('Loss', n_edges_loss), ('Prodigy', n_edges_prodigy),\n",
    "                     ('Uncons', n_edges_uncons), ('Zero', n_edges_zero)]:\n",
    "    plt.hist(edges, alpha=0.5, bins=20, label=label)\n",
    "plt.xlabel('Number of edges')\n",
    "plt.ylabel('Number of graphs')\n",
    "plt.legend()\n",
    "plt.title('Distribution of the number of edges', fontsize=18);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# n_edges - (n_nodes - 1) equals 0 for a star (and for any other tree on the same nodes).\n",
    "for label, edges, nodes in [('Greedy', n_edges_greedy, n_nodes_greedy),\n",
    "                            ('Loss', n_edges_loss, n_nodes_loss),\n",
    "                            ('Prodigy', n_edges_prodigy, n_nodes_prodigy),\n",
    "                            ('Uncons', n_edges_uncons, n_nodes_uncons),\n",
    "                            ('Zero', n_edges_zero, n_nodes_zero)]:\n",
    "    plt.hist(edges - (nodes - 1), alpha=0.5, bins=20, label=label)\n",
    "plt.xlabel('edges - (nodes - 1)')\n",
    "plt.ylabel('Number of graphs')\n",
    "plt.legend()\n",
    "plt.title('Distance from star', fontsize=18);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Triangle statistic per method; it sums nx.triangles over nodes, i.e. 3x the triangle count.\n",
    "for label, triangles in [('Greedy', n_triangles_greedy), ('Loss', n_triangles_loss),\n",
    "                         ('Prodigy', n_triangles_prodigy), ('Uncons', n_triangles_uncons),\n",
    "                         ('Zero', n_triangles_zero)]:\n",
    "    plt.hist(triangles, alpha=0.5, bins=20, label=label)\n",
    "plt.xlabel('Sum of per-node triangle counts')\n",
    "plt.ylabel('Number of graphs')\n",
    "plt.legend()\n",
    "plt.title('Distribution of the number of triangles', fontsize=18);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def satisfaction_rate(graph_list):\n",
    "    \"\"\"Fraction of graphs in graph_list that satisfy constraint_config, via prodigy's satisfies().\n",
    "\n",
    "    Each graph is written into the top-left block of a zero-padded adjacency tensor and the\n",
    "    node features are all zeros. Returns nan for an empty list instead of dividing by zero.\n",
    "    \"\"\"\n",
    "    if len(graph_list) == 0:\n",
    "        return float('nan')\n",
    "    n_max = configt.data.max_node_num\n",
    "    adjs = torch.zeros(len(graph_list), n_max, n_max)\n",
    "    for k, G in enumerate(graph_list):\n",
    "        nG = G.number_of_nodes()\n",
    "        adjs[k, :nG, :nG] = torch.tensor(nx.adjacency_matrix(G).todense())\n",
    "    xs = torch.zeros(len(graph_list), n_max, configt.data.max_feat_num)\n",
    "    return satisfies(xs, adjs, constraint_config).sum().item() / len(adjs)\n",
    "\n",
    "# Constraint satisfaction rate of every method, including zero-order.\n",
    "{name: satisfaction_rate(graph_list) for name, graph_list in [\n",
    "    ('Greedy', gen_graph_list_greedy), ('Loss', gen_graph_list_loss), ('Zero', gen_graph_list_zero),\n",
    "    ('Prodigy', gen_graph_list_prodigy), ('Uncons', gen_graph_list_uncons)]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fraction of loss-guided samples that satisfy the constraint.\n",
    "n_max = configt.data.max_node_num\n",
    "n_graphs = len(gen_graph_list_loss)\n",
    "adjs = torch.zeros(n_graphs, n_max, n_max)\n",
    "for idx, graph in enumerate(gen_graph_list_loss):\n",
    "    size = graph.number_of_nodes()\n",
    "    adjs[idx, :size, :size] = torch.tensor(nx.adjacency_matrix(graph).todense())\n",
    "xs = torch.zeros(n_graphs, n_max, configt.data.max_feat_num)\n",
    "satisfies(xs, adjs, constraint_config).sum().item() / len(adjs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of PRODIGY samples that satisfy the constraint.\n",
    "n_max = configt.data.max_node_num\n",
    "n_graphs = len(gen_graph_list_prodigy)\n",
    "adjs = torch.zeros(n_graphs, n_max, n_max)\n",
    "for idx, graph in enumerate(gen_graph_list_prodigy):\n",
    "    size = graph.number_of_nodes()\n",
    "    adjs[idx, :size, :size] = torch.tensor(nx.adjacency_matrix(graph).todense())\n",
    "xs = torch.zeros(n_graphs, n_max, configt.data.max_feat_num)\n",
    "satisfies(xs, adjs, constraint_config).sum().item() / len(adjs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of unconstrained samples that satisfy the constraint.\n",
    "n_max = configt.data.max_node_num\n",
    "n_graphs = len(gen_graph_list_uncons)\n",
    "adjs = torch.zeros(n_graphs, n_max, n_max)\n",
    "for idx, graph in enumerate(gen_graph_list_uncons):\n",
    "    size = graph.number_of_nodes()\n",
    "    adjs[idx, :size, :size] = torch.tensor(nx.adjacency_matrix(graph).todense())\n",
    "xs = torch.zeros(n_graphs, n_max, configt.data.max_feat_num)\n",
    "satisfies(xs, adjs, constraint_config).sum().item() / len(adjs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forwarded as filter_test to every evaluate() call below.\n",
    "# NOTE(review): presumably controls filtering of the reference test graphs - confirm in prodigy/evals/evaluate_gdss.py.\n",
    "filter_test = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metrics for every method and for the test set itself.\n",
    "results_dict_greedy = evaluate(gen_graph_list_greedy, configt, config, constraint_config, filter_test=filter_test)\n",
    "results_dict_loss = evaluate(gen_graph_list_loss, configt, config, constraint_config, filter_test=filter_test)\n",
    "results_dict_prodigy = evaluate(gen_graph_list_prodigy, configt, config, constraint_config, filter_test=filter_test)\n",
    "results_dict_uncons = evaluate(gen_graph_list_uncons, configt, config, constraint_config, filter_test=filter_test)\n",
    "results_dict_test = evaluate(test_graph_list, configt, config, constraint_config, filter_test=filter_test)\n",
    "results_dict_zero = evaluate(gen_graph_list_zero, configt, config, constraint_config, filter_test=filter_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the mean of the constrained statistic to each row, under the same key for every row.\n",
    "if constraint_name == 'nedges':\n",
    "    results_dict_greedy['nedges'] = n_edges_greedy.mean()\n",
    "    results_dict_loss['nedges'] = n_edges_loss.mean()\n",
    "    results_dict_zero['nedges'] = n_edges_zero.mean()\n",
    "    results_dict_prodigy['nedges'] = n_edges_prodigy.mean()\n",
    "    results_dict_uncons['nedges'] = n_edges_uncons.mean()\n",
    "    results_dict_test['nedges'] = n_edges_test.mean()\n",
    "elif constraint_name == 'ntriangles':\n",
    "    results_dict_greedy['ntriangles'] = n_triangles_greedy.mean()\n",
    "    results_dict_loss['ntriangles'] = n_triangles_loss.mean()\n",
    "    results_dict_zero['ntriangles'] = n_triangles_zero.mean()\n",
    "    results_dict_prodigy['ntriangles'] = n_triangles_prodigy.mean()\n",
    "    results_dict_uncons['ntriangles'] = n_triangles_uncons.mean()\n",
    "    results_dict_test['ntriangles'] = n_triangles_test.mean()\n",
    "elif constraint_name == 'degree':\n",
    "    results_dict_greedy['max_degree'] = max_degree_greedy.mean()\n",
    "    results_dict_loss['max_degree'] = max_degree_loss.mean()\n",
    "    results_dict_zero['max_degree'] = max_degree_zero.mean()\n",
    "    results_dict_prodigy['max_degree'] = max_degree_prodigy.mean()\n",
    "    results_dict_uncons['max_degree'] = max_degree_uncons.mean()\n",
    "    results_dict_test['max_degree'] = max_degree_test.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mmd_names = ['degree', 'cluster', 'orbit']  # 'spectral' is left out of the average\n",
    "mmd_avg_greedy = np.array([results_dict_greedy[name] for name in mmd_names]).mean()\n",
    "mmd_avg_loss = np.array([results_dict_loss[name] for name in mmd_names]).mean()\n",
    "mmd_avg_zero = np.array([results_dict_zero[name] for name in mmd_names]).mean()\n",
    "mmd_avg_prodigy = np.array([results_dict_prodigy[name] for name in mmd_names]).mean()\n",
    "mmd_avg_uncons = np.array([results_dict_uncons[name] for name in mmd_names]).mean()\n",
    "mmd_avg_test = np.array([results_dict_test[name] for name in mmd_names]).mean()\n",
    "# delta_mmd > 0 means a lower average MMD than the unconstrained sampler.\n",
    "results_dict_greedy['delta_mmd'] = mmd_avg_uncons - mmd_avg_greedy\n",
    "results_dict_loss['delta_mmd'] = mmd_avg_uncons - mmd_avg_loss\n",
    "results_dict_zero['delta_mmd'] = mmd_avg_uncons - mmd_avg_zero\n",
    "results_dict_prodigy['delta_mmd'] = mmd_avg_uncons - mmd_avg_prodigy\n",
    "results_dict_test['delta_mmd'] = mmd_avg_uncons - mmd_avg_test\n",
    "results_dict_uncons['delta_mmd'] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary table: best (lowest) MMDs and highest constraint value highlighted per column.\n",
    "all_results = [results_dict_greedy, results_dict_loss, results_dict_zero, results_dict_prodigy,\n",
    "               results_dict_uncons, results_dict_test]\n",
    "df = pd.DataFrame(all_results, columns=list(results_dict_greedy.keys()),\n",
    "                  index=['Greedy', 'Loss', 'Zero', 'Prodigy', 'Unconstrained', 'Test'])\n",
    "(df.style\n",
    "   .highlight_min(color='lightgreen', subset=['degree', 'cluster', 'orbit', 'spectral'])\n",
    "   .highlight_max(color='lightgreen', subset=['constr_val']))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gdss_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
