{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import networkx as nx\n",
    "import matplotlib.pyplot as plt\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from tqdm.notebook import trange\n",
    "import numpy as np\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import ReverseDiffusionPredictor, EulerMaruyamaPredictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import adjs_to_graphs, init_flags, quantize"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = load_device()\n",
    "device = [1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config_file = 'sample_ego_small'\n",
    "seed = 0\n",
    "config = get_config(config_file, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load checkpoint --------\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "load_seed(configt.seed)\n",
    "train_graph_list, test_graph_list = load_data(configt, get_graph_list=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "logger = Logger(str(os.path.join(log_dir, f'{log_name}.log')), mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "configt.data.init = \"deg\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_edges_test = np.array([g.number_of_edges() for g in test_graph_list])\n",
    "n_triangles_test = np.array([sum(list(nx.triangles(g).values())) for g in test_graph_list])\n",
    "max_degrees_test = np.array([max([x[1] for x in g.degree()]) for g in test_graph_list])\n",
    "\n",
    "max_edges = np.percentile(n_edges_test, 10)\n",
    "max_triangles = np.percentile(n_triangles_test, 10)\n",
    "max_degree = np.percentile(max_degrees_test, 10)\n",
    "constraint_param_map = {'nedges': max_edges, 'ntriangles': max_triangles, 'degree': max_degree}\n",
    "max_edges, max_triangles, max_degree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "constraint_name = 'nedges' # nedges, ntriangles, degree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "satisfies_nedges = lambda g: g.number_of_edges() <= max_edges\n",
    "satisfies_degree = lambda g: max([x[1] for x in g.degree()]) <= max_degree\n",
    "satisfies_ntriangles = lambda g: sum(list(nx.triangles(g).values())) <= max_triangles\n",
    "satisfies_force_stars = lambda g: g.number_of_edges() == (g.number_of_nodes() - 1)\n",
    "\n",
    "satisfies_fn = eval('satisfies_' + constraint_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if config_file == \"sample_ego_small\":\n",
    "    valid_fn = lambda g: nx.is_isomorphic(g, nx.star_graph(g.number_of_nodes()-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not check_log(log_folder_name, log_name):\n",
    "    logger.log(f'{log_name}')\n",
    "    start_log(logger, configt)\n",
    "    train_log(logger, configt)\n",
    "sample_log(logger, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load models --------\n",
    "model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if config.sample.use_ema:\n",
    "    ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "    ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "    \n",
    "    ema_x.copy_to(model_x.parameters())\n",
    "    ema_adj.copy_to(model_adj.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'GEN SEED: {config.sample.seed}')\n",
    "load_seed(config.sample.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sde_x = load_sde(configt.sde.x)\n",
    "sde_adj = load_sde(configt.sde.adj)\n",
    "max_node_num  = configt.data.max_node_num\n",
    "\n",
    "device_id = f'cuda:{device[0]}' if isinstance(device, list) else device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "continuous = True\n",
    "\n",
    "predictor = config.sampler.predictor\n",
    "corrector = config.sampler.corrector\n",
    "\n",
    "snr=config.sampler.snr\n",
    "scale_eps=config.sampler.scale_eps\n",
    "n_steps=config.sampler.n_steps\n",
    "\n",
    "probability_flow = config.sample.probability_flow\n",
    "eps = config.sample.eps\n",
    "denoise = config.sample.noise_removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if configt.data.data in ['QM9']: # 'ZINC250k' gives OOM\n",
    "    shape_x = (10000, max_node_num, configt.data.max_feat_num)\n",
    "    shape_adj = (10000, max_node_num, max_node_num)\n",
    "else:\n",
    "    shape_x = (configt.data.batch_size, max_node_num, configt.data.max_feat_num)\n",
    "    shape_adj = (configt.data.batch_size, max_node_num, max_node_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_flags_iter = init_flags(train_graph_list, configt).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_x = sde_x.prior_sampling(shape_x).to(device_id)\n",
    "init_adj = sde_adj.prior_sampling_sym(shape_adj).to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None):\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        if init_x is not None:\n",
    "            x = init_x.clone()\n",
    "        else:\n",
    "            x = predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        if init_adj is not None:\n",
    "            adj = init_adj.clone()\n",
    "        else:\n",
    "            adj = predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "        \n",
    "        if flags is None:\n",
    "            flags = init_flags(train_graph_list, configt).to(device_id)\n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "        diff_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, diff_steps, device=device_id)\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):\n",
    "            t = timesteps[i]\n",
    "            vec_t = torch.ones(shape_adj[0], device=t.device) * t\n",
    "\n",
    "            _x = x\n",
    "            x, x_mean = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = corrector_adj.update_fn(_x, adj, flags, vec_t)\n",
    "            if torch.any(torch.isnan(adj)):\n",
    "                return None\n",
    "\n",
    "            _x = x\n",
    "            x, x_mean = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = predictor_adj.update_fn(_x, adj, flags, vec_t)\n",
    "        print(' ')\n",
    "    samples_int = quantize(adj)\n",
    "    gen_graph_list = adjs_to_graphs(samples_int, True)\n",
    "    return gen_graph_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_vals = [1] + list(range(2, 21, 2))\n",
    "N_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/constrained/{constraint_name}/greedy.yaml')\n",
    "guidance_args = edict({'method': 'greedy', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "satisfies_greedy = {}\n",
    "valid_greedy = {}\n",
    "pct_one_node_greedy = {}\n",
    "\n",
    "for n in N_vals:\n",
    "    print(f'Running for N={n}')\n",
    "    guidance_args['n_traj'] = n\n",
    "    predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "    corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "    gen_graph_list_greedy = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)\n",
    "\n",
    "    satisfies_greedy[n] = np.mean([satisfies_fn(g) for g in gen_graph_list_greedy])\n",
    "    valid_greedy[n] = np.mean([valid_fn(g) for g in gen_graph_list_greedy]) if 'valid_fn' in locals() else None\n",
    "    pct_one_node_greedy[n] = np.mean([g.number_of_nodes() == 1 for g in gen_graph_list_greedy])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/constrained/{constraint_name}/loss.yaml')\n",
    "guidance_args = edict({'method': 'loss', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lr_vals = [1e-2, 1e-1, 5e-1, 1., 5., 10., 50., 100., 150., 200., 250., 300.]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "satisfies_loss = {}\n",
    "valid_loss = {}\n",
    "pct_one_node_loss = {}\n",
    "\n",
    "for lr in lr_vals:\n",
    "    print(f'Running for lr={lr}')\n",
    "    guidance_args['lr_guidance'] = lr\n",
    "    predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "    corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "    gen_graph_list_loss = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)\n",
    "    if gen_graph_list_loss is None:\n",
    "        satisfies_loss[lr] = 0.0\n",
    "        valid_loss[lr] = 0.0\n",
    "        pct_one_node_loss[lr] = 0.0\n",
    "        continue\n",
    "\n",
    "    satisfies_loss[lr] = np.mean([satisfies_fn(g) for g in gen_graph_list_loss])\n",
    "    valid_loss[lr] = np.mean([valid_fn(g) for g in gen_graph_list_loss]) if 'valid_fn' in locals() else None\n",
    "    pct_one_node_loss[lr] = np.mean([g.number_of_nodes() == 1 for g in gen_graph_list_loss])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero order optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guidance_config = load_yaml_config(f'config_guidance/constrained/{constraint_name}/zero.yaml')\n",
    "guidance_args = edict({'method': 'zero', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "predictor_fn = ReverseDiffusionPredictor if predictor=='Reverse' else EulerMaruyamaPredictor \n",
    "corrector_fn = LangevinCorrector if corrector=='Langevin' else NoneCorrector\n",
    "\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "satisfies_zero = {}\n",
    "valid_zero = {}\n",
    "pct_one_node_zero = {}\n",
    "\n",
    "for n in N_vals:\n",
    "    print(f'Running for N={n}')\n",
    "    guidance_args['n_traj'] = n\n",
    "    predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "    corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "    gen_graph_list_zero = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj, init_x, init_adj, init_flags_iter)\n",
    "    if gen_graph_list_zero is None:\n",
    "        satisfies_zero[n] = 0.0\n",
    "        valid_zero[n] = 0.0\n",
    "        pct_one_node_zero[n] = 0.0\n",
    "        continue\n",
    "\n",
    "    satisfies_zero[n] = np.mean([satisfies_fn(g) for g in gen_graph_list_zero])\n",
    "    valid_zero[n] = np.mean([valid_fn(g) for g in gen_graph_list_zero]) if 'valid_fn' in locals() else None\n",
    "    pct_one_node_zero[n] = np.mean([g.number_of_nodes() == 1 for g in gen_graph_list_zero])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "color = 'tab:red'\n",
    "ax1.set_xlabel('$\\\\lambda$')\n",
    "ax1.set_ylabel('Val$_{\\\\mathcal{C}}$')\n",
    "ax1.semilogx(lr_vals, satisfies_loss.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "color = 'tab:red'\n",
    "ax1.set_xlabel('$N$')\n",
    "ax1.set_ylabel('Val$_{\\\\mathcal{C}}$')\n",
    "ax1.plot(N_vals, satisfies_greedy.values())\n",
    "ax1.set_xticks(N_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax1 = plt.subplots()\n",
    "\n",
    "color = 'tab:red'\n",
    "ax1.set_xlabel('$N$')\n",
    "ax1.set_ylabel('Val$_{\\\\mathcal{C}}$')\n",
    "ax1.plot(N_vals, satisfies_zero.values())\n",
    "ax1.set_xticks(N_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"$N$ | \" + \"|\".join([str(x) for x in N_vals]))\n",
    "print(\"---\" + \" | --\"*len(N_vals))\n",
    "print(\"GGDiff-C | \" + \"|\".join([str(round(x,3)) for x in satisfies_greedy.values()]))\n",
    "print(\"GGDiff-Z | \" + \"|\".join([str(round(x,3)) for x in satisfies_zero.values()]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"$\\\\lambda$ | \" + \"|\".join([str(x) for x in lr_vals]))\n",
    "print(\"--------\" + \" | --\"*len(lr_vals))\n",
    "print(\"GGDiff-G | \" + \"|\".join([str(round(100*x,2)) for x in satisfies_loss.values()]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('results/ablation/incomplete_constrained.pkl', 'wb') as f:\n",
    "    pickle.dump({\n",
    "        'N_vals': N_vals, 'lr_vals': lr_vals,\n",
    "        'satisfies_loss': satisfies_loss,\n",
    "        'satisfies_greedy': satisfies_greedy,\n",
    "        'satisfies_zero': satisfies_zero}, f\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "digress-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
