{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from tqdm.notebook import trange\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import ReverseDiffusionPredictor, EulerMaruyamaPredictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import adjs_to_graphs, init_flags, quantize\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
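    "# Data and preamble\n",
    "\n",
    "Benchmark the wall-clock cost of guided sampling. For each dataset below, time one full reverse-diffusion run without guidance (`N=0`) and with greedy guidance (`config_guidance/constrained/nedges/greedy.yaml`) using `N` trajectories (`N=1..5`)."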
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GPU indices to sample on; set to None to keep whatever `load_device()` selects.\n",
    "GPU_OVERRIDE = [1]\n",
    "\n",
    "device = load_device()\n",
    "if GPU_OVERRIDE is not None:\n",
    "    # Fail fast with a clear message rather than an opaque CUDA error deep inside sampling.\n",
    "    n_visible_gpus = torch.cuda.device_count()\n",
    "    if max(GPU_OVERRIDE) >= n_visible_gpus:\n",
    "        raise RuntimeError(f'GPU_OVERRIDE={GPU_OVERRIDE} requested, but only {n_visible_gpus} CUDA device(s) are visible')\n",
    "    device = GPU_OVERRIDE\n",
    "device_id = f'cuda:{device[0]}' if isinstance(device, list) else device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets to benchmark; each one needs a `sample_<name>` config and a trained checkpoint.\n",
    "# Shorten the list (e.g. to just 'zinc250k') for a quicker single-dataset run.\n",
    "datasets = [\n",
    "    'ego_small',\n",
    "    'enzymes',\n",
    "    'community_small',\n",
    "    'qm9',\n",
    "    'zinc250k',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None):\n",
    "    \"\"\"Jointly denoise node features and adjacency with predictor-corrector reverse diffusion.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    predictor_x, corrector_x : objects exposing ``update_fn(x, adj, flags, t)`` (predictor also ``.sde``)\n",
    "        Sampler steps for the node-feature tensor ``x``.\n",
    "    predictor_adj, corrector_adj : same interface\n",
    "        Sampler steps for the adjacency tensor ``adj``; ``predictor_adj.sde`` also fixes the\n",
    "        number of steps (``N``) and the start time (``T``).\n",
    "    init_x, init_adj : torch.Tensor, optional\n",
    "        Starting noise, cloned before use. When omitted, fresh samples are drawn from the SDE\n",
    "        priors with the notebook-level ``shape_x`` / ``shape_adj``.\n",
    "    flags : torch.Tensor, optional\n",
    "        Node mask; built with ``init_flags(train_graph_list, configt)`` when omitted.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The output of ``adjs_to_graphs`` on the quantized final adjacency, or ``None`` if the\n",
    "    adjacency became NaN.\n",
    "\n",
    "    NOTE: reads notebook globals set by the benchmark cell: ``eps`` and ``device_id`` always;\n",
    "    ``shape_x``, ``shape_adj``, ``train_graph_list`` and ``configt`` only for omitted arguments.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        if init_x is not None:\n",
    "            x = init_x.clone()\n",
    "        else:\n",
    "            x = predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        if init_adj is not None:\n",
    "            adj = init_adj.clone()\n",
    "        else:\n",
    "            adj = predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "\n",
    "        if flags is None:\n",
    "            flags = init_flags(train_graph_list, configt).to(device_id)\n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "        diff_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, diff_steps, device=device_id)\n",
    "        # Use the batch size of the tensor actually being sampled (not the global `shape_adj`)\n",
    "        # so a custom-sized `init_adj` gets a time vector of matching length.\n",
    "        batch_size = adj.shape[0]\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for i in trange(diff_steps, desc='[Sampling]', position=1, leave=False):\n",
    "            t = timesteps[i]\n",
    "            vec_t = torch.ones(batch_size, device=t.device) * t\n",
    "\n",
    "            # Corrector: both updates see the pre-step `x` (simultaneous update).\n",
    "            _x = x\n",
    "            x, _ = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, _ = corrector_adj.update_fn(_x, adj, flags, vec_t)\n",
    "            if torch.isnan(adj).any():\n",
    "                return None\n",
    "\n",
    "            # Predictor: same simultaneous-update scheme.\n",
    "            _x = x\n",
    "            x, _ = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, _ = predictor_adj.update_fn(_x, adj, flags, vec_t)\n",
    "\n",
    "        # The in-loop check runs before each predictor step, so a NaN produced by the\n",
    "        # final predictor step would otherwise flow unchecked into `quantize`.\n",
    "        if torch.isnan(adj).any():\n",
    "            return None\n",
    "\n",
    "    samples_int = quantize(adj)\n",
    "    gen_graph_list = adjs_to_graphs(samples_int, True)\n",
    "    return gen_graph_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# N=0 is plain (unguided) sampling; N>=1 runs greedy guidance with `n_traj = N`.\n",
    "N_vals = list(range(6))\n",
    "GUIDANCE_CONFIG_PATH = 'config_guidance/constrained/nedges/greedy.yaml'\n",
    "\n",
    "\n",
    "def _sync_device():\n",
    "    \"\"\"Block until queued CUDA work finishes, so wall-clock timings cover the GPU work itself.\"\"\"\n",
    "    if str(device_id).startswith('cuda'):\n",
    "        torch.cuda.synchronize(device_id)\n",
    "\n",
    "\n",
    "# The guidance YAML does not depend on the dataset or N, so read it once (only if some run needs it).\n",
    "guidance_config = load_yaml_config(GUIDANCE_CONFIG_PATH) if any(n != 0 for n in N_vals) else None\n",
    "\n",
    "results = []\n",
    "for dataset in datasets:\n",
    "    config_file = 'sample_' + dataset\n",
    "    seed = 0\n",
    "    config = get_config(config_file, seed)\n",
    "\n",
    "    # -------- Load checkpoint --------\n",
    "    ckpt_dict = load_ckpt(config, device)\n",
    "    configt = ckpt_dict['config']\n",
    "\n",
    "    load_seed(configt.seed)\n",
    "    train_graph_list, test_graph_list = load_data(configt, get_graph_list=True)\n",
    "\n",
    "    print(\"Mean number of edges:\", np.mean([g.number_of_edges() for g in train_graph_list]))\n",
    "    print(\"90 percentile number of edges:\", np.percentile([g.number_of_edges() for g in train_graph_list], 90))\n",
    "\n",
    "    log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "    log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "    logger = Logger(str(os.path.join(log_dir, f'{log_name}.log')), mode='a')\n",
    "    configt.data.init = \"deg\"\n",
    "\n",
    "    if not check_log(log_folder_name, log_name):\n",
    "        logger.log(f'{log_name}')\n",
    "        start_log(logger, configt)\n",
    "        train_log(logger, configt)\n",
    "    sample_log(logger, config)\n",
    "\n",
    "    # -------- Load models --------\n",
    "    model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "    model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)\n",
    "\n",
    "    if config.sample.use_ema:\n",
    "        ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "        ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "\n",
    "        ema_x.copy_to(model_x.parameters())\n",
    "        ema_adj.copy_to(model_adj.parameters())\n",
    "\n",
    "    print(f'GEN SEED: {config.sample.seed}')\n",
    "    load_seed(config.sample.seed)\n",
    "\n",
    "    sde_x = load_sde(configt.sde.x)\n",
    "    sde_adj = load_sde(configt.sde.adj)\n",
    "    max_node_num = configt.data.max_node_num\n",
    "\n",
    "    continuous = True\n",
    "\n",
    "    predictor = config.sampler.predictor\n",
    "    corrector = config.sampler.corrector\n",
    "\n",
    "    snr = config.sampler.snr\n",
    "    scale_eps = config.sampler.scale_eps\n",
    "    n_steps = config.sampler.n_steps\n",
    "\n",
    "    probability_flow = config.sample.probability_flow\n",
    "    # NOTE: `eps`, `shape_x`, `shape_adj`, `configt` and `train_graph_list` are also read\n",
    "    # as globals by `sample`, so they must stay module-level names in this cell.\n",
    "    eps = config.sample.eps\n",
    "\n",
    "    shape_x = (configt.data.batch_size, max_node_num, configt.data.max_feat_num)\n",
    "    shape_adj = (configt.data.batch_size, max_node_num, max_node_num)\n",
    "\n",
    "    # Fixed initial noise and node mask, shared by every N so the runs differ only in guidance.\n",
    "    init_flags_iter = init_flags(train_graph_list, configt).to(device_id)\n",
    "    init_x = sde_x.prior_sampling(shape_x).to(device_id)\n",
    "    init_adj = sde_adj.prior_sampling_sym(shape_adj).to(device_id)\n",
    "\n",
    "    # Sampler classes depend only on the config, not on N.\n",
    "    predictor_fn = ReverseDiffusionPredictor if predictor == 'Reverse' else EulerMaruyamaPredictor\n",
    "    corrector_fn = LangevinCorrector if corrector == 'Langevin' else NoneCorrector\n",
    "\n",
    "    results_dataset = {\n",
    "        'Dataset': dataset,\n",
    "        'Max. Node Number': max_node_num,\n",
    "        'Corrector': (corrector + f'(Steps {n_steps})') if corrector == \"Langevin\" else corrector,\n",
    "    }\n",
    "\n",
    "    for n in N_vals:\n",
    "        if n == 0:\n",
    "            guidance_args = None\n",
    "        else:\n",
    "            guidance_args = edict({'method': 'greedy', 'obj': guidance_config['obj'], **guidance_config[configt.data.data.lower()]})\n",
    "            guidance_args['n_traj'] = n\n",
    "\n",
    "        score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "        score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "        predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "        corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "\n",
    "        predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "        corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "        # CUDA kernels run asynchronously: drain pending work before starting the clock and\n",
    "        # again before stopping it so the interval covers exactly this sampling run.\n",
    "        # NOTE(review): N=0 runs first for each dataset, so it may absorb one-off warm-up\n",
    "        # costs; consider an untimed warm-up run if the N=0 baseline looks inflated.\n",
    "        _sync_device()\n",
    "        t_start = time.perf_counter()\n",
    "        gen_graphs = sample(predictor_obj_x, corrector_obj_x, predictor_obj_adj, corrector_obj_adj,\n",
    "                            init_x, init_adj, init_flags_iter)\n",
    "        _sync_device()\n",
    "        elapsed = time.perf_counter() - t_start\n",
    "\n",
    "        if gen_graphs is None:\n",
    "            # `sample` returns None when it aborts early on NaNs; that truncated time is not comparable.\n",
    "            print(f'WARNING: sampling diverged (NaN) for {dataset}, N={n}; recording NaN time')\n",
    "            elapsed = float('nan')\n",
    "        results_dataset[f'Time N={n}'] = elapsed\n",
    "\n",
    "    results.append(results_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per dataset; each `Time N=<n>` column holds the wall-clock seconds of one sampling run.\n",
    "results_df = (\n",
    "    pd.DataFrame(results)\n",
    "    .set_index('Dataset')\n",
    ")\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same results table, rendered as Markdown text.\n",
    "markdown_table = results_df.to_markdown()\n",
    "print(markdown_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick up every `Time N=<n>` column so the plot follows whatever `N_vals` the benchmark used.\n",
    "time_cols = [col for col in results_df.columns if str(col).startswith('Time N=')]\n",
    "ax = results_df[time_cols].T.plot(marker='o')\n",
    "ax.set(title='Sampling time vs. number of guidance trajectories',\n",
    "       xlabel='Run (N=0: no guidance)', ylabel='Wall-clock time [s]')\n",
    "ax.legend(title='Dataset');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "digress-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
