{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from easydict import EasyDict as edict\n",
    "\n",
    "from losses import get_score_fn\n",
    "from solver_guidance import ReverseDiffusionPredictor, Predictor, LangevinCorrector, NoneCorrector\n",
    "from utils.graph_utils import mask_adjs, mask_x, gen_noise\n",
    "from tqdm.notebook import trange\n",
    "\n",
    "from parsers.config import get_config\n",
    "\n",
    "from utils.logger import Logger, set_log, start_log, train_log, sample_log, check_log\n",
    "from utils.loader import load_ckpt, load_data, load_seed, load_device, load_model_from_ckpt, \\\n",
    "                         load_ema_from_ckpt, load_sde, load_yaml_config\n",
    "from utils.graph_utils import adjs_to_graphs, init_flags, quantize, \\\n",
    "                              compute_group_assignments, count_across_community_edges, est_p_intra_inter, \\\n",
    "                              is_sbm_graph\n",
    "\n",
    "from losses_guidance import compute_dp1, compute_dp2, compute_nodedp1, compute_nodedp2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data and preamble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `load_device` may hand back a list of GPU ids or a non-list device\n",
    "# (the notebook later branches on `isinstance(device, list)`).\n",
    "device = load_device()\n",
    "# Restrict to the first detected GPU only when a GPU list came back, so a\n",
    "# non-GPU device is left untouched instead of being forced to 'cuda:0'.\n",
    "if isinstance(device, list):\n",
    "    device = device[:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sampling config to load (alternatives: sample_sbm, sample_community_small)\n",
    "config_file, seed = 'sample_community_small', 0\n",
    "config = get_config(config_file, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load checkpoint --------\n",
    "# `configt` is the configuration object stored inside the checkpoint.\n",
    "ckpt_dict = load_ckpt(config, device)\n",
    "configt = ckpt_dict['config']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed with the checkpoint's seed, then fetch the train/test graph lists.\n",
    "load_seed(configt.seed)\n",
    "graph_lists = load_data(configt, get_graph_list=True)\n",
    "train_graph_list, test_graph_list = graph_lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Log entries are appended to '<ckpt>-sample-guidance.log' inside `log_dir`.\n",
    "log_folder_name, log_dir, _ = set_log(configt, is_train=False)\n",
    "log_name = f\"{config.ckpt}-sample-guidance\"\n",
    "log_path = os.path.join(log_dir, f'{log_name}.log')\n",
    "logger = Logger(str(log_path), mode='a')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, axs = plt.subplots(5, 5, figsize=(10, 10))\n",
    "# Show a different test graph in each panel rather than drawing\n",
    "# test_graph_list[0] in all 25 of them.\n",
    "for ax, test_graph in zip(axs.ravel(), test_graph_list):\n",
    "    nx.draw(test_graph, with_labels=False, ax=ax, node_size=10)\n",
    "# Blank out panels left over when fewer than 25 test graphs are available.\n",
    "for ax in axs.ravel()[len(test_graph_list):]:\n",
    "    ax.set_axis_off()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Header, start and training entries are written only when `check_log`\n",
    "# returns a falsy value for this log; the sampling entry is always written.\n",
    "log_found = check_log(log_folder_name, log_name)\n",
    "if not log_found:\n",
    "    logger.log(f'{log_name}')\n",
    "    start_log(logger, configt)\n",
    "    train_log(logger, configt)\n",
    "\n",
    "sample_log(logger, config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -------- Load models --------\n",
    "model_x = load_model_from_ckpt(ckpt_dict['params_x'], ckpt_dict['x_state_dict'], device)\n",
    "model_adj = load_model_from_ckpt(ckpt_dict['params_adj'], ckpt_dict['adj_state_dict'], device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# When enabled, overwrite both models' parameters with the EMA copies\n",
    "# stored in the checkpoint.\n",
    "if config.sample.use_ema:\n",
    "    ema_x = load_ema_from_ckpt(model_x, ckpt_dict['ema_x'], configt.train.ema)\n",
    "    ema_adj = load_ema_from_ckpt(model_adj, ckpt_dict['ema_adj'], configt.train.ema)\n",
    "\n",
    "    for ema, model in ((ema_x, model_x), (ema_adj, model_adj)):\n",
    "        ema.copy_to(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed with the sampling seed so generation is reproducible.\n",
    "gen_seed = config.sample.seed\n",
    "print(f'GEN SEED: {gen_seed}')\n",
    "load_seed(gen_seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SDE objects for features and adjacencies, built from the checkpoint config.\n",
    "sde_x, sde_adj = load_sde(configt.sde.x), load_sde(configt.sde.adj)\n",
    "max_node_num = configt.data.max_node_num\n",
    "\n",
    "# A device list is collapsed to a 'cuda:<first id>' string; anything else is used as is.\n",
    "if isinstance(device, list):\n",
    "    device_id = f'cuda:{device[0]}'\n",
    "else:\n",
    "    device_id = device\n",
    "print(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "continuous = True\n",
    "\n",
    "# Solver choice and corrector hyper-parameters from the sampler section.\n",
    "sampler_cfg = config.sampler\n",
    "predictor, corrector = sampler_cfg.predictor, sampler_cfg.corrector\n",
    "snr, scale_eps, n_steps = sampler_cfg.snr, sampler_cfg.scale_eps, sampler_cfg.n_steps\n",
    "\n",
    "# Remaining sampling options from the sample section.\n",
    "sample_cfg = config.sample\n",
    "probability_flow = sample_cfg.probability_flow\n",
    "eps = sample_cfg.eps\n",
    "denoise = sample_cfg.noise_removal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# QM9 / ZINC250k use a fixed batch of 10000 samples; every other dataset\n",
    "# uses the batch size stored in the checkpoint config.\n",
    "uses_fixed_batch = configt.data.data in ['QM9', 'ZINC250k']\n",
    "n_samples = 10000 if uses_fixed_batch else configt.data.batch_size\n",
    "shape_x = (n_samples, max_node_num, configt.data.max_feat_num)\n",
    "shape_adj = (n_samples, max_node_num, max_node_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node flags created once and passed to every sampling run below.\n",
    "init_flags_iter = init_flags(train_graph_list, configt)\n",
    "init_flags_iter = init_flags_iter.to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior samples drawn once so that every method starts from the same noise.\n",
    "init_x = sde_x.prior_sampling(shape_x)\n",
    "init_x = init_x.to(device_id)\n",
    "init_adj = sde_adj.prior_sampling_sym(shape_adj)\n",
    "init_adj = init_adj.to(device_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeated here for cleaner code\n",
    "class EulerMaruyamaPredictor(Predictor):\n",
    "    \"\"\"Euler-Maruyama predictor for the reverse SDE with optional guidance.\n",
    "\n",
    "    ``guidance_args.method`` selects one of three guidance schemes:\n",
    "    'greedy' (keep the lowest-loss candidate among several sampled steps),\n",
    "    'zero' (zero-order gradient estimate from random perturbations) and\n",
    "    'loss' (autograd gradient of the loss). ``self.rsde``, ``self.sde`` and\n",
    "    ``self.score_fn`` are read here and are presumably set by\n",
    "    ``Predictor.__init__`` - TODO confirm.\n",
    "\n",
    "    Args:\n",
    "        obj: 'x' to update node features, 'adj' to update adjacencies.\n",
    "        sde: SDE object for that variable.\n",
    "        score_fn: Callable invoked as ``score_fn(x, adj, flags, t)``.\n",
    "        probability_flow: Passed through to ``Predictor``.\n",
    "        guidance_args: Attribute-style dict with the guidance settings, or\n",
    "            None to run without guidance.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, obj, sde, score_fn, probability_flow=False, guidance_args=None):\n",
    "        super().__init__(sde, score_fn, probability_flow)\n",
    "        self.obj = obj\n",
    "        self.guidance_args = guidance_args\n",
    "        # Group assignments, created lazily on the first guided adjacency step.\n",
    "        self.Z = None\n",
    "\n",
    "    def _init_group_assignments(self, x, adj, flags, t, timestep):\n",
    "        \"\"\"Fill ``self.Z`` with a tensor of shape (adj.shape[0], n_com, adj.shape[1]).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``guidance_args.method_Z`` is neither \"communities\" nor \"random\".\n",
    "        \"\"\"\n",
    "        n_com = self.guidance_args.get('n_com', 2)\n",
    "        method_Z = self.guidance_args.method_Z\n",
    "        if method_Z == \"communities\":\n",
    "            # Assign groups on the graphs obtained from the current obj0estimation output.\n",
    "            adj0 = self.sde.obj0estimation(adj, self.score_fn(x, adj, flags, t), timestep)\n",
    "            adj0 = mask_adjs(adj0, flags)\n",
    "            gs = adjs_to_graphs(quantize(adj0), True)\n",
    "            Zs = torch.zeros((adj.shape[0], adj.shape[1], n_com)).to(adj.device)\n",
    "            for g, graph in enumerate(gs):\n",
    "                Zs[g, :graph.number_of_nodes(), :] = compute_group_assignments(graph, n_com)\n",
    "        elif method_Z == \"random\":\n",
    "            idxs_com = np.random.randint(0, n_com, size=(adj.shape[0], adj.shape[1]))\n",
    "            # one_hot needs int64 indices; numpy's default integer can be int32 on some platforms.\n",
    "            idxs_com = torch.tensor(idxs_com, dtype=torch.long)\n",
    "            Zs = torch.nn.functional.one_hot(idxs_com, num_classes=n_com).float().to(adj.device)\n",
    "        else:\n",
    "            raise ValueError(f\"method_Z {method_Z} not supported.\")\n",
    "        self.Z = Zs.permute(0, 2, 1)\n",
    "\n",
    "    def _greedy_step(self, obj, x, adj, flags, t, dt, is_adj, estimate_loss):\n",
    "        \"\"\"Sample ``n_traj`` candidate steps and keep, per batch element, the lowest-loss one.\"\"\"\n",
    "        n_traj = self.guidance_args.n_traj\n",
    "\n",
    "        drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=is_adj)\n",
    "        obj_mean = obj + drift * dt\n",
    "\n",
    "        # Allocated on the sample's device, matching the zero-order step.\n",
    "        losses = torch.zeros(n_traj, obj.shape[0], device=obj.device)\n",
    "        obj_hats = []\n",
    "        for i in range(n_traj):\n",
    "            z = gen_noise(obj, flags, sym=is_adj)\n",
    "            obj_hat = obj_mean + diffusion[:, None, None] * np.sqrt(-dt) * z\n",
    "            obj_hats.append(obj_hat)\n",
    "            loss_i, _ = estimate_loss(obj_hat, t + dt)\n",
    "            losses[i, :] = loss_i\n",
    "\n",
    "        # Broadcast the winning candidate index over the trailing two dimensions for gather.\n",
    "        best = torch.argmin(losses, dim=0)\n",
    "        index = best.view(1, obj.shape[0], 1, 1).expand(1, obj.shape[0], obj.shape[1], obj.shape[2])\n",
    "        return torch.gather(torch.stack(obj_hats, dim=0), 0, index).squeeze(0), obj_mean\n",
    "\n",
    "    def _zero_order_step(self, obj, x, adj, flags, t, dt, is_adj, estimate_loss):\n",
    "        \"\"\"Take a reverse step, then descend along a finite-difference gradient estimate.\"\"\"\n",
    "        n_traj = self.guidance_args.n_traj\n",
    "        delta = self.guidance_args.delta\n",
    "\n",
    "        drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=is_adj)\n",
    "        obj_mean = obj + drift * dt\n",
    "\n",
    "        z = gen_noise(obj, flags, sym=is_adj)\n",
    "        # Noise is added to the drifted mean, as in the other two methods; adding it\n",
    "        # to the un-drifted sample would skip the reverse drift altogether.\n",
    "        obj = obj_mean + diffusion[:, None, None] * np.sqrt(-dt) * z\n",
    "\n",
    "        # Reference loss and perturbed losses are all evaluated at t + dt.\n",
    "        no_noise_loss, _ = estimate_loss(obj, t + dt)\n",
    "\n",
    "        losses = torch.zeros(n_traj, obj.shape[0], device=obj.device)\n",
    "        noise_directions = []\n",
    "        for i in range(n_traj):\n",
    "            z = gen_noise(obj, flags, sym=is_adj)\n",
    "            noise_directions.append(z)\n",
    "            loss_i, _ = estimate_loss(obj + delta * z, t + dt)\n",
    "            losses[i, :] = loss_i\n",
    "\n",
    "        # Finite-difference slope along each random direction, averaged over directions.\n",
    "        weights = (losses - no_noise_loss[None, :]) / delta\n",
    "        directions = torch.stack(noise_directions, dim=0)\n",
    "        weighted_directions = (weights[:, :, None, None] * directions).mean(dim=0)\n",
    "\n",
    "        clip_method = self.guidance_args.clip_method\n",
    "        if clip_method == \"clip\":\n",
    "            # Gradient clipping\n",
    "            weighted_directions = torch.clamp(weighted_directions, -self.guidance_args.clip, self.guidance_args.clip)\n",
    "        elif clip_method == \"norm\":\n",
    "            norm_grad_step = torch.linalg.norm(weighted_directions, dim=(1, 2)) / (weighted_directions.shape[1] * weighted_directions.shape[2])\n",
    "            # Leave (near-)zero gradients unscaled instead of dividing by ~0.\n",
    "            norm_grad_step = torch.where(norm_grad_step < 1e-7, torch.ones_like(norm_grad_step), norm_grad_step)\n",
    "            weighted_directions = weighted_directions / norm_grad_step[:, None, None]\n",
    "\n",
    "        return obj - self.guidance_args.lr_zero * weighted_directions, obj_mean\n",
    "\n",
    "    def _loss_gradient_step(self, obj, flags, t, dt, is_adj, estimate_loss):\n",
    "        \"\"\"Take a reverse step and subtract the autograd gradient of the guidance loss.\"\"\"\n",
    "        with torch.enable_grad():\n",
    "            # Differentiate through a detached copy so the caller's tensor keeps its\n",
    "            # requires_grad flag untouched.\n",
    "            obj_req = obj.detach().clone().requires_grad_(True)\n",
    "            loss, score = estimate_loss(obj_req, t)\n",
    "            loss = loss.mean()\n",
    "            loss.backward()\n",
    "            obj_grad = obj_req.grad.detach().clone()\n",
    "        loss, score = loss.detach(), score.detach()\n",
    "\n",
    "        # The forward SDE is evaluated on the variable being updated (not always adj).\n",
    "        drift, diffusion = self.sde.sde(obj, t)\n",
    "        drift = drift - diffusion[:, None, None] ** 2 * score * (0.5 if self.rsde.probability_flow else 1.)\n",
    "        # -------- Set the diffusion function to zero for ODEs. --------\n",
    "        # Kept as a tensor so the indexing below still works.\n",
    "        if self.rsde.probability_flow:\n",
    "            diffusion = torch.zeros_like(diffusion)\n",
    "\n",
    "        obj_mean = obj + drift * dt\n",
    "\n",
    "        z = gen_noise(obj, flags, sym=is_adj)\n",
    "        obj = obj_mean + diffusion[:, None, None] * np.sqrt(-dt) * z\n",
    "\n",
    "        step_size = self.guidance_args.lr_guidance\n",
    "        if self.guidance_args.lr_guidance_method == 'adaptive':\n",
    "            # Floor the denominator so a vanishing loss cannot produce inf/NaN updates.\n",
    "            step_size = step_size / loss.abs().clamp_min(1e-12)\n",
    "\n",
    "        return obj - step_size * obj_grad, obj_mean\n",
    "\n",
    "    def guidance(self, x, adj, flags, t, is_adj):\n",
    "        \"\"\"Take one guided reverse step.\n",
    "\n",
    "        Args:\n",
    "            x, adj: Current node features and adjacencies.\n",
    "            flags: Node flags forwarded to the score, masking and noise helpers.\n",
    "            t: Per-sample time tensor.\n",
    "            is_adj: True to update ``adj``, False to update ``x``.\n",
    "\n",
    "        Returns:\n",
    "            Tuple (updated sample, ``obj + drift * dt`` without noise or guidance).\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: for an unknown ``guidance_args.method``.\n",
    "            ValueError: for an unknown ``guidance_args.method_Z`` on the adjacency predictor.\n",
    "        \"\"\"\n",
    "        obj = adj if is_adj else x\n",
    "\n",
    "        dt = -1. / self.rsde.N\n",
    "        timestep = (t * (self.rsde.N - 1) / self.rsde.T).long()\n",
    "\n",
    "        # NOTE(review): eval() executes the string stored in the guidance config;\n",
    "        # only load guidance configs from trusted sources.\n",
    "        loss_fn = eval(self.guidance_args.loss_fn)\n",
    "        # Shallow copy so the tensor added below is never written into the shared config.\n",
    "        loss_kwargs = dict(self.guidance_args.get('loss_kwargs', {}))\n",
    "\n",
    "        if self.obj == 'adj' and self.Z is None:\n",
    "            self._init_group_assignments(x, adj, flags, t, timestep)\n",
    "        if self.Z is not None:\n",
    "            loss_kwargs['Z'] = self.Z.clone()\n",
    "\n",
    "        def estimate_loss(candidate, time):\n",
    "            # Score and loss with `candidate` standing in for the variable being updated.\n",
    "            if is_adj:\n",
    "                score = self.score_fn(x, candidate, flags, time)\n",
    "            else:\n",
    "                score = self.score_fn(candidate, adj, flags, time)\n",
    "            obj0hat = self.sde.obj0estimation(candidate, score, timestep)\n",
    "            obj0hat_masked = mask_adjs(obj0hat, flags) if is_adj else mask_x(obj0hat, flags)\n",
    "            return loss_fn(obj0hat_masked, **loss_kwargs), score\n",
    "\n",
    "        method = self.guidance_args.method\n",
    "        if method == 'greedy':\n",
    "            return self._greedy_step(obj, x, adj, flags, t, dt, is_adj, estimate_loss)\n",
    "        if method == 'zero':\n",
    "            return self._zero_order_step(obj, x, adj, flags, t, dt, is_adj, estimate_loss)\n",
    "        if method == 'loss':\n",
    "            return self._loss_gradient_step(obj, flags, t, dt, is_adj, estimate_loss)\n",
    "        raise NotImplementedError(f\"guidance method {self.guidance_args.method} not yet supported.\")\n",
    "\n",
    "    def update_fn(self, x, adj, flags, t):\n",
    "        \"\"\"Advance the tracked variable by one reverse step.\n",
    "\n",
    "        Guidance is applied when ``guidance_args`` is set, the discrete timestep is\n",
    "        positive and, if ``from_t`` is configured, the timestep is below it.\n",
    "\n",
    "        Returns:\n",
    "            Tuple (updated sample, mean of the step without noise).\n",
    "        \"\"\"\n",
    "        dt = -1. / self.rsde.N\n",
    "        timestep = (t[0] * (self.rsde.N - 1) / self.rsde.T).long()\n",
    "        step_index = timestep.item()\n",
    "\n",
    "        var = x if self.obj == 'x' else adj\n",
    "        is_adj = self.obj == 'adj'\n",
    "\n",
    "        use_guidance = self.guidance_args is not None and step_index > 0\n",
    "        if use_guidance and 'from_t' in self.guidance_args:\n",
    "            use_guidance = step_index < self.guidance_args.from_t\n",
    "\n",
    "        if use_guidance:\n",
    "            return self.guidance(x, adj, flags, t, is_adj=is_adj)\n",
    "\n",
    "        # Plain Euler-Maruyama step.\n",
    "        z = gen_noise(var, flags, sym=is_adj)\n",
    "        drift, diffusion = self.rsde.sde(x, adj, flags, t, is_adj=is_adj)\n",
    "        var_mean = var + drift * dt\n",
    "        var = var_mean + diffusion[:, None, None] * np.sqrt(-dt) * z\n",
    "        return var, var_mean\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(predictor_x, corrector_x,\n",
    "           predictor_adj, corrector_adj,\n",
    "           init_x=None, init_adj=None, flags=None):\n",
    "    \"\"\"Run the predictor-corrector reverse diffusion and return the generated graphs.\n",
    "\n",
    "    Args:\n",
    "        predictor_x, corrector_x: Solvers for the node features; each exposes\n",
    "            ``update_fn(x, adj, flags, t)``.\n",
    "        predictor_adj, corrector_adj: Solvers for the adjacencies.\n",
    "            ``predictor_adj.sde`` supplies the number of steps ``N`` and horizon ``T``.\n",
    "        init_x: Optional starting features (cloned); when None a prior sample of\n",
    "            shape ``shape_x`` is drawn.\n",
    "        init_adj: Optional starting adjacencies (cloned); when None a symmetric\n",
    "            prior sample of shape ``shape_adj`` is drawn.\n",
    "        flags: Optional node flags; when None they are built with ``init_flags``.\n",
    "\n",
    "    Returns:\n",
    "        The list produced by ``adjs_to_graphs(quantize(adj), True)``.\n",
    "\n",
    "    Reads the notebook-level names ``shape_x``, ``shape_adj``, ``device_id``,\n",
    "    ``eps``, ``train_graph_list`` and ``configt``.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # -------- Initial sample --------\n",
    "        if init_x is None:\n",
    "            x = predictor_x.sde.prior_sampling(shape_x).to(device_id)\n",
    "        else:\n",
    "            x = init_x.clone()\n",
    "        if init_adj is None:\n",
    "            adj = predictor_adj.sde.prior_sampling_sym(shape_adj).to(device_id)\n",
    "        else:\n",
    "            adj = init_adj.clone()\n",
    "\n",
    "        if flags is None:\n",
    "            flags = init_flags(train_graph_list, configt).to(device_id)\n",
    "        x = mask_x(x, flags)\n",
    "        adj = mask_adjs(adj, flags)\n",
    "        diff_steps = predictor_adj.sde.N\n",
    "        timesteps = torch.linspace(predictor_adj.sde.T, eps, diff_steps, device=device_id)\n",
    "\n",
    "        # -------- Reverse diffusion process --------\n",
    "        for i in trange(0, (diff_steps), desc = '[Sampling]', position = 1, leave=False):\n",
    "            t = timesteps[i]\n",
    "            # Size the time vector from the actual batch, so a caller-supplied\n",
    "            # init_adj whose batch differs from shape_adj still works.\n",
    "            vec_t = torch.ones(adj.shape[0], device=t.device) * t\n",
    "\n",
    "            # Both correctors see the features from before this corrector step.\n",
    "            _x = x\n",
    "            x, x_mean = corrector_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = corrector_adj.update_fn(_x, adj, flags, vec_t)\n",
    "\n",
    "            # Likewise, both predictors see the features from before this predictor step.\n",
    "            _x = x\n",
    "            x, x_mean = predictor_x.update_fn(x, adj, flags, vec_t)\n",
    "            adj, adj_mean = predictor_adj.update_fn(_x, adj, flags, vec_t)\n",
    "        print(' ')\n",
    "    samples_int = quantize(adj)\n",
    "    gen_graph_list = adjs_to_graphs(samples_int, True)\n",
    "    return gen_graph_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group-assignment strategy; also selects the guidance config directory below.\n",
    "method_Z = 'random'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the method name, the shared 'obj' entry and the dataset-specific\n",
    "# settings (which take precedence) into one attribute-style dict.\n",
    "guidance_config = load_yaml_config(f'config_guidance/fairness/{method_Z}/greedy.yaml')\n",
    "guidance_settings = {'method': 'greedy', 'obj': guidance_config['obj']}\n",
    "guidance_settings.update(guidance_config[configt.data.data.lower()])\n",
    "guidance_args = edict(guidance_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions for features and adjacencies, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Solver classes named by the sampler config.\n",
    "predictor_fn = EulerMaruyamaPredictor if predictor != 'Reverse' else ReverseDiffusionPredictor\n",
    "corrector_fn = NoneCorrector if corrector != 'Langevin' else LangevinCorrector\n",
    "\n",
    "# Only the adjacency predictor receives the guidance settings.\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "gen_graph_list_greedy = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph summary statistics of the greedy samples.\n",
    "n_nodes_greedy = np.array([graph.number_of_nodes() for graph in gen_graph_list_greedy])\n",
    "n_edges_greedy = np.array([graph.number_of_edges() for graph in gen_graph_list_greedy])\n",
    "# Sum of per-node triangle counts (NetworkX counts a triangle once at each of its nodes).\n",
    "n_triangles_greedy = np.array([sum(nx.triangles(graph).values()) for graph in gen_graph_list_greedy])\n",
    "max_degrees_greedy = np.array([max(deg for _, deg in graph.degree()) for graph in gen_graph_list_greedy])\n",
    "# Mean and standard deviation of the edge counts.\n",
    "n_edges_greedy.mean(), n_edges_greedy.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_group_assignments_tensor(adj, n_com):\n",
    "    \"\"\"Group assignments for a batch of adjacencies.\n",
    "\n",
    "    Converts ``adj`` to graphs with ``adjs_to_graphs``, runs\n",
    "    ``compute_group_assignments`` on each graph and writes the result into a\n",
    "    zero-initialised tensor, leaving rows beyond a graph's node count at zero.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (adj.shape[0], n_com, adj.shape[1]) on ``adj.device``.\n",
    "    \"\"\"\n",
    "    n_graphs, n_nodes = adj.shape[0], adj.shape[1]\n",
    "    assignments = torch.zeros((n_graphs, n_nodes, n_com)).to(adj.device)\n",
    "    for idx, graph in enumerate(adjs_to_graphs(adj, True)):\n",
    "        assignments[idx, :graph.number_of_nodes(), :] = compute_group_assignments(graph, n_com)\n",
    "    return assignments.permute(0, 2, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at one greedy sample (index 3).\n",
    "example_graph_greedy = gen_graph_list_greedy[3]\n",
    "nx.draw(example_graph_greedy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild zero-padded dense adjacencies from the generated graphs.\n",
    "adjs_greedy = torch.zeros(len(gen_graph_list_greedy), configt.data.max_node_num, configt.data.max_node_num)\n",
    "for i, G in enumerate(gen_graph_list_greedy):\n",
    "    nG = G.number_of_nodes()\n",
    "    adjs_greedy[i, :nG, :nG] = torch.tensor(nx.adjacency_matrix(G).todense())\n",
    "xs_greedy = torch.zeros(len(gen_graph_list_greedy), configt.data.max_node_num, configt.data.max_feat_num)\n",
    "\n",
    "# Move the adjacencies to the device once instead of once per metric.\n",
    "adjs_greedy_dev = adjs_greedy.to(device_id)\n",
    "if guidance_args.method_Z == \"communities\":\n",
    "    Zs_greedy = compute_group_assignments_tensor(adjs_greedy_dev, guidance_args.n_com).to(device_id)\n",
    "else:\n",
    "    # Reuse the assignments drawn by the guided adjacency predictor.\n",
    "    Zs_greedy = predictor_obj_adj.Z.clone().to(device_id)\n",
    "\n",
    "# Evaluate each metric once; mean and std are both read from the same result.\n",
    "dp1_values_greedy = compute_dp1(adjs_greedy_dev, Zs_greedy)\n",
    "dp2_values_greedy = compute_dp2(adjs_greedy_dev, Zs_greedy)\n",
    "nodedp1_values_greedy = compute_nodedp1(adjs_greedy_dev, Zs_greedy)\n",
    "nodedp2_values_greedy = compute_nodedp2(adjs_greedy_dev, Zs_greedy)\n",
    "\n",
    "dp1_greedy = dp1_values_greedy.mean().item()\n",
    "dp2_greedy = dp2_values_greedy.mean().item()\n",
    "dp2_greedy_std = dp2_values_greedy.std().item()\n",
    "nodedp1_greedy = nodedp1_values_greedy.mean().item()\n",
    "nodedp2_greedy = nodedp2_values_greedy.mean().item()\n",
    "nodedp2_greedy_std = nodedp2_values_greedy.std().item()\n",
    "across_comm_edges_greedy = count_across_community_edges(adjs_greedy_dev, Zs_greedy).mean().item()\n",
    "\n",
    "dp1_greedy, dp2_greedy, nodedp1_greedy, nodedp2_greedy, across_comm_edges_greedy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the method name, the shared 'obj' entry and the dataset-specific\n",
    "# settings (which take precedence) into one attribute-style dict.\n",
    "guidance_config = load_yaml_config(f'config_guidance/fairness/{method_Z}/loss.yaml')\n",
    "guidance_settings = {'method': 'loss', 'obj': guidance_config['obj']}\n",
    "guidance_settings.update(guidance_config[configt.data.data.lower()])\n",
    "guidance_args = edict(guidance_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions for features and adjacencies, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Solver classes named by the sampler config.\n",
    "predictor_fn = EulerMaruyamaPredictor if predictor != 'Reverse' else ReverseDiffusionPredictor\n",
    "corrector_fn = NoneCorrector if corrector != 'Langevin' else LangevinCorrector\n",
    "\n",
    "# Only the adjacency predictor receives the guidance settings.\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "gen_graph_list_loss = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph summary statistics of the loss-guided samples.\n",
    "n_nodes_loss = np.array([graph.number_of_nodes() for graph in gen_graph_list_loss])\n",
    "n_edges_loss = np.array([graph.number_of_edges() for graph in gen_graph_list_loss])\n",
    "# Sum of per-node triangle counts (NetworkX counts a triangle once at each of its nodes).\n",
    "n_triangles_loss = np.array([sum(nx.triangles(graph).values()) for graph in gen_graph_list_loss])\n",
    "max_degrees_loss = np.array([max(deg for _, deg in graph.degree()) for graph in gen_graph_list_loss])\n",
    "# Mean and standard deviation of the edge counts.\n",
    "n_edges_loss.mean(), n_edges_loss.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild zero-padded dense adjacencies from the generated graphs.\n",
    "adjs_loss = torch.zeros(len(gen_graph_list_loss), configt.data.max_node_num, configt.data.max_node_num)\n",
    "for i, G in enumerate(gen_graph_list_loss):\n",
    "    nG = G.number_of_nodes()\n",
    "    adjs_loss[i, :nG, :nG] = torch.tensor(nx.adjacency_matrix(G).todense())\n",
    "xs_loss = torch.zeros(len(gen_graph_list_loss), configt.data.max_node_num, configt.data.max_feat_num)\n",
    "\n",
    "# Move the adjacencies to the device once instead of once per metric.\n",
    "adjs_loss_dev = adjs_loss.to(device_id)\n",
    "if guidance_args.method_Z == \"communities\":\n",
    "    Zs_loss = compute_group_assignments_tensor(adjs_loss_dev, guidance_args.n_com).to(device_id)\n",
    "else:\n",
    "    # Reuse the assignments drawn by the guided adjacency predictor.\n",
    "    Zs_loss = predictor_obj_adj.Z.clone().to(device_id)\n",
    "\n",
    "# Evaluate each metric once; mean and std are both read from the same result.\n",
    "dp1_values_loss = compute_dp1(adjs_loss_dev, Zs_loss)\n",
    "dp2_values_loss = compute_dp2(adjs_loss_dev, Zs_loss)\n",
    "nodedp1_values_loss = compute_nodedp1(adjs_loss_dev, Zs_loss)\n",
    "nodedp2_values_loss = compute_nodedp2(adjs_loss_dev, Zs_loss)\n",
    "\n",
    "dp1_loss = dp1_values_loss.mean().item()\n",
    "dp2_loss = dp2_values_loss.mean().item()\n",
    "dp2_loss_std = dp2_values_loss.std().item()\n",
    "nodedp1_loss = nodedp1_values_loss.mean().item()\n",
    "nodedp2_loss = nodedp2_values_loss.mean().item()\n",
    "nodedp2_loss_std = nodedp2_values_loss.std().item()\n",
    "across_comm_edges_loss = count_across_community_edges(adjs_loss_dev, Zs_loss).mean().item()\n",
    "\n",
    "dp1_loss, dp2_loss, nodedp1_loss, nodedp2_loss, across_comm_edges_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Five randomly chosen loss-guided samples along the first row of a 5x5 grid.\n",
    "for i in range(5):\n",
    "    plt.subplot(5, 5, i + 1)\n",
    "    picked_loss_idx = np.random.randint(len(gen_graph_list_loss))\n",
    "    nx.draw(gen_graph_list_loss[picked_loss_idx], with_labels=False, node_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Zero-order optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the method name, the shared 'obj' entry and the dataset-specific\n",
    "# settings (which take precedence) into one attribute-style dict.\n",
    "guidance_config = load_yaml_config(f'config_guidance/fairness/{method_Z}/zero.yaml')\n",
    "guidance_settings = {'method': 'zero', 'obj': guidance_config['obj']}\n",
    "guidance_settings.update(guidance_config[configt.data.data.lower()])\n",
    "guidance_args = edict(guidance_settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions for features and adjacencies, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Solver classes named by the sampler config.\n",
    "predictor_fn = EulerMaruyamaPredictor if predictor != 'Reverse' else ReverseDiffusionPredictor\n",
    "corrector_fn = NoneCorrector if corrector != 'Langevin' else LangevinCorrector\n",
    "\n",
    "# Only the adjacency predictor receives the guidance settings.\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow, guidance_args=guidance_args)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "gen_graph_list_zero = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph summary statistics of the zero-order samples.\n",
    "n_nodes_zero = np.array([graph.number_of_nodes() for graph in gen_graph_list_zero])\n",
    "n_edges_zero = np.array([graph.number_of_edges() for graph in gen_graph_list_zero])\n",
    "# Sum of per-node triangle counts (NetworkX counts a triangle once at each of its nodes).\n",
    "n_triangles_zero = np.array([sum(nx.triangles(graph).values()) for graph in gen_graph_list_zero])\n",
    "max_degrees_zero = np.array([max(deg for _, deg in graph.degree()) for graph in gen_graph_list_zero])\n",
    "# Mean and standard deviation of the edge counts.\n",
    "n_edges_zero.mean(), n_edges_zero.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild zero-padded dense adjacencies from the generated graphs.\n",
    "adjs_zero = torch.zeros(len(gen_graph_list_zero), configt.data.max_node_num, configt.data.max_node_num)\n",
    "for i, G in enumerate(gen_graph_list_zero):\n",
    "    nG = G.number_of_nodes()\n",
    "    adjs_zero[i, :nG, :nG] = torch.tensor(nx.adjacency_matrix(G).todense())\n",
    "xs_zero = torch.zeros(len(gen_graph_list_zero), configt.data.max_node_num, configt.data.max_feat_num)\n",
    "\n",
    "# Move the adjacencies to the device once instead of once per metric.\n",
    "adjs_zero_dev = adjs_zero.to(device_id)\n",
    "if guidance_args.method_Z == \"communities\":\n",
    "    Zs_zero = compute_group_assignments_tensor(adjs_zero_dev, guidance_args.n_com).to(device_id)\n",
    "else:\n",
    "    # Reuse the assignments drawn by the guided adjacency predictor.\n",
    "    Zs_zero = predictor_obj_adj.Z.clone().to(device_id)\n",
    "\n",
    "# Evaluate each metric once; mean and std are both read from the same result.\n",
    "dp1_values_zero = compute_dp1(adjs_zero_dev, Zs_zero)\n",
    "dp2_values_zero = compute_dp2(adjs_zero_dev, Zs_zero)\n",
    "nodedp1_values_zero = compute_nodedp1(adjs_zero_dev, Zs_zero)\n",
    "nodedp2_values_zero = compute_nodedp2(adjs_zero_dev, Zs_zero)\n",
    "\n",
    "dp1_zero = dp1_values_zero.mean().item()\n",
    "dp2_zero = dp2_values_zero.mean().item()\n",
    "dp2_zero_std = dp2_values_zero.std().item()\n",
    "nodedp1_zero = nodedp1_values_zero.mean().item()\n",
    "nodedp2_zero = nodedp2_values_zero.mean().item()\n",
    "nodedp2_zero_std = nodedp2_values_zero.std().item()\n",
    "across_comm_edges_zero = count_across_community_edges(adjs_zero_dev, Zs_zero).mean().item()\n",
    "\n",
    "dp1_zero, dp2_zero, nodedp1_zero, nodedp2_zero, across_comm_edges_zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Five randomly chosen zero-order samples along the first row of a 5x5 grid.\n",
    "# (This section must plot gen_graph_list_zero, not the loss-guided list.)\n",
    "for i in range(5):\n",
    "    plt.subplot(5, 5, i+1)\n",
    "    nx.draw(gen_graph_list_zero[np.random.randint(len(gen_graph_list_zero))], with_labels=False, node_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unconstrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score functions for features and adjacencies, built with train=False.\n",
    "score_fn_x = get_score_fn(sde_x, model_x, train=False, continuous=continuous)\n",
    "score_fn_adj = get_score_fn(sde_adj, model_adj, train=False, continuous=continuous)\n",
    "\n",
    "# Solver classes named by the sampler config.\n",
    "predictor_fn = EulerMaruyamaPredictor if predictor != 'Reverse' else ReverseDiffusionPredictor\n",
    "corrector_fn = NoneCorrector if corrector != 'Langevin' else LangevinCorrector\n",
    "\n",
    "# Baseline run: neither predictor receives guidance settings.\n",
    "predictor_obj_x = predictor_fn('x', sde_x, score_fn_x, probability_flow)\n",
    "predictor_obj_adj = predictor_fn('adj', sde_adj, score_fn_adj, probability_flow)\n",
    "corrector_obj_x = corrector_fn('x', sde_x, score_fn_x, snr, scale_eps, n_steps)\n",
    "corrector_obj_adj = corrector_fn('adj', sde_adj, score_fn_adj, snr, scale_eps, n_steps)\n",
    "\n",
    "gen_graph_list_uncons = sample(\n",
    "    predictor_obj_x, corrector_obj_x,\n",
    "    predictor_obj_adj, corrector_obj_adj,\n",
    "    init_x, init_adj, init_flags_iter,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-graph summary statistics of the unconstrained samples.\n",
    "n_nodes_uncons = np.array([graph.number_of_nodes() for graph in gen_graph_list_uncons])\n",
    "n_edges_uncons = np.array([graph.number_of_edges() for graph in gen_graph_list_uncons])\n",
    "# Sum of per-node triangle counts (NetworkX counts a triangle once at each of its nodes).\n",
    "n_triangles_uncons = np.array([sum(nx.triangles(graph).values()) for graph in gen_graph_list_uncons])\n",
    "max_degrees_uncons = np.array([max(deg for _, deg in graph.degree()) for graph in gen_graph_list_uncons])\n",
    "# Mean and standard deviation of the edge counts.\n",
    "n_edges_uncons.mean(), n_edges_uncons.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild zero-padded dense adjacencies from the generated graphs.\n",
    "adjs_uncons = torch.zeros(len(gen_graph_list_uncons), configt.data.max_node_num, configt.data.max_node_num)\n",
    "for i, G in enumerate(gen_graph_list_uncons):\n",
    "    nG = G.number_of_nodes()\n",
    "    adjs_uncons[i, :nG, :nG] = torch.tensor(nx.adjacency_matrix(G).todense())\n",
    "xs_uncons = torch.zeros(len(gen_graph_list_uncons), configt.data.max_node_num, configt.data.max_feat_num)\n",
    "\n",
    "# Move the adjacencies to the device once instead of once per metric.\n",
    "adjs_uncons_dev = adjs_uncons.to(device_id)\n",
    "# NOTE(review): the unconstrained run has no guidance config of its own, so the\n",
    "# number of groups is read from whichever guidance_args was loaded last. Both\n",
    "# branches now use the same lookup with the same default of 2.\n",
    "n_com_uncons = guidance_args.get('n_com', 2)\n",
    "if method_Z == \"communities\":\n",
    "    Zs_uncons = compute_group_assignments_tensor(adjs_uncons_dev, n_com_uncons).to(device_id)\n",
    "else:\n",
    "    idxs_com = np.random.randint(0, n_com_uncons, size=(adjs_uncons.shape[0], adjs_uncons.shape[1]))\n",
    "    # one_hot needs int64 indices; numpy's default integer can be int32 on some platforms.\n",
    "    Zs = torch.nn.functional.one_hot(torch.tensor(idxs_com, dtype=torch.long), num_classes=n_com_uncons).float().to(device_id)\n",
    "    Zs_uncons = Zs.permute(0,2,1)\n",
    "\n",
    "# Evaluate each metric once; mean and std are both read from the same result.\n",
    "dp1_values_uncons = compute_dp1(adjs_uncons_dev, Zs_uncons)\n",
    "dp2_values_uncons = compute_dp2(adjs_uncons_dev, Zs_uncons)\n",
    "nodedp1_values_uncons = compute_nodedp1(adjs_uncons_dev, Zs_uncons)\n",
    "nodedp2_values_uncons = compute_nodedp2(adjs_uncons_dev, Zs_uncons)\n",
    "\n",
    "dp1_uncons = dp1_values_uncons.mean().item()\n",
    "dp2_uncons = dp2_values_uncons.mean().item()\n",
    "dp2_uncons_std = dp2_values_uncons.std().item()\n",
    "nodedp1_uncons = nodedp1_values_uncons.mean().item()\n",
    "nodedp2_uncons = nodedp2_values_uncons.mean().item()\n",
    "nodedp2_uncons_std = nodedp2_values_uncons.std().item()\n",
    "across_comm_edges_uncons = count_across_community_edges(adjs_uncons_dev, Zs_uncons).mean().item()\n",
    "\n",
    "dp1_uncons, dp2_uncons, nodedp1_uncons, nodedp2_uncons, across_comm_edges_uncons"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f, ax = plt.subplots(4, 5, figsize=(20,12))\n",
    "\n",
    "def obtain_colors(Z, N):\n",
    "    \"\"\"Return the argmax of Z along axis 0 as a NumPy array, truncated to its first N entries.\"\"\"\n",
    "    return Z.argmax(axis=0).cpu().numpy()[:N]\n",
    "\n",
    "# (graphs, group assignments) for each figure row, top to bottom.\n",
    "sample_rows = [\n",
    "    (gen_graph_list_loss, Zs_loss),\n",
    "    (gen_graph_list_greedy, Zs_greedy),\n",
    "    (gen_graph_list_zero, Zs_zero),\n",
    "    (gen_graph_list_uncons, Zs_uncons),\n",
    "]\n",
    "\n",
    "for i in range(5):\n",
    "    # One random sample index per column, shared by all four rows.\n",
    "    j = np.random.randint(len(gen_graph_list_greedy))\n",
    "    for row_idx, (row_graphs, row_Zs) in enumerate(sample_rows):\n",
    "        nx.draw(row_graphs[j], ax=ax[row_idx, i], node_size=30, node_color=obtain_colors(row_Zs[j,:,:], row_graphs[j].number_of_nodes()), cmap='coolwarm')\n",
    "\n",
    "# NOTE(review): row 0 draws the `loss` samples and row 1 the `greedy` samples;\n",
    "# confirm that this matches the order of the first two labels below.\n",
    "methods = [\"GGDiff-G\", \"GGDiff-C\", \"GGDiff-Z\", \"Uncons.\"]\n",
    "for j in range(4):\n",
    "    for i in range(5):\n",
    "        # nx.draw switches the axes off; turn them back on so the row label is rendered.\n",
    "        ax[j,i].axis('on')\n",
    "    ax[j,0].set_ylabel(methods[j], fontsize=24)\n",
    "# exist_ok=True creates the directory only when missing, without a separate existence check.\n",
    "os.makedirs('results/fair_graph_generation', exist_ok=True)\n",
    "f.savefig(f'results/fair_graph_generation/{method_Z}-samples.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the dp1 value of each variant side by side: greedy, loss, zero, unconstrained.\n",
    "(dp1_greedy, dp1_loss, dp1_zero, dp1_uncons)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the dp2 value of each variant side by side: greedy, loss, zero, unconstrained.\n",
    "(dp2_greedy, dp2_loss, dp2_zero, dp2_uncons)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "factor = 8.\n",
    "\n",
    "def _valid_sbm_fraction(graph_list):\n",
    "    \"\"\"Mean of is_sbm_graph(graph, factor=factor) over the graphs in graph_list.\"\"\"\n",
    "    flags = [is_sbm_graph(graph, factor=factor) for graph in graph_list]\n",
    "    return np.array(flags).mean()\n",
    "\n",
    "# The values are fractions in [0, 1] despite the `pct_` prefix.\n",
    "pct_valid_sbm_greedy = _valid_sbm_fraction(gen_graph_list_greedy)\n",
    "pct_valid_sbm_loss = _valid_sbm_fraction(gen_graph_list_loss)\n",
    "pct_valid_sbm_zero = _valid_sbm_fraction(gen_graph_list_zero)\n",
    "pct_valid_sbm_uncons = _valid_sbm_fraction(gen_graph_list_uncons)\n",
    "pct_valid_sbm_test = _valid_sbm_fraction(test_graph_list)\n",
    "\n",
    "pct_valid_sbm_greedy, pct_valid_sbm_loss, pct_valid_sbm_zero, pct_valid_sbm_uncons, pct_valid_sbm_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One est_p_intra_inter estimate per generated graph, for each variant.\n",
    "est_ps_greedy = [est_p_intra_inter(graph) for graph in gen_graph_list_greedy]\n",
    "est_ps_loss = [est_p_intra_inter(graph) for graph in gen_graph_list_loss]\n",
    "est_ps_zero = [est_p_intra_inter(graph) for graph in gen_graph_list_zero]\n",
    "est_ps_uncons = [est_p_intra_inter(graph) for graph in gen_graph_list_uncons]\n",
    "\n",
    "def _second_over_first(est_ps):\n",
    "    \"\"\"Ratio entry[1] / entry[0] for every estimate whose entry[0] is positive.\"\"\"\n",
    "    # Presumably each entry is (p_intra, p_inter), going by the helper's name - confirm.\n",
    "    return np.array([entry[1] / entry[0] for entry in est_ps if entry[0] > 0])\n",
    "\n",
    "ratios_greedy = _second_over_first(est_ps_greedy)\n",
    "ratios_loss = _second_over_first(est_ps_loss)\n",
    "ratios_zero = _second_over_first(est_ps_zero)\n",
    "ratios_uncons = _second_over_first(est_ps_uncons)\n",
    "\n",
    "# Overlay the four ratio histograms on the current axes.\n",
    "for ratios_to_plot, legend_label in (\n",
    "    (ratios_greedy, 'Greedy'),\n",
    "    (ratios_loss, 'Loss'),\n",
    "    (ratios_zero, 'Zero'),\n",
    "    (ratios_uncons, 'Uncons'),\n",
    "):\n",
    "    plt.hist(ratios_to_plot, alpha=0.5, bins=20, label=legend_label)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (printed label, variable-name prefix); the method name is appended to the prefix.\n",
    "report_fields = [\n",
    "    ('DP1', 'dp1_'),\n",
    "    ('DP2', 'dp2_'),\n",
    "    ('Node DP1', 'nodedp1_'),\n",
    "    ('Node DP2', 'nodedp2_'),\n",
    "    ('Across community edges', 'across_comm_edges_'),\n",
    "]\n",
    "for method in ['greedy', 'loss', 'zero', 'uncons']:\n",
    "    print(f\"Method: {method}\")\n",
    "    for field_label, var_prefix in report_fields:\n",
    "        # Plain namespace lookup instead of eval() on a built-up string.\n",
    "        print(f\"{field_label}: {globals()[var_prefix + method]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Emit a LaTeX (booktabs) table with one row per method.\n",
    "print(\"\\\\begin{table}[H]\")\n",
    "print(\"\\\\centering\")\n",
    "print(\"\\\\caption{Metrics for the fair graph generation.}\")\n",
    "print(\"\\\\label{tab:fair_graph_metrics}\")\n",
    "# Four columns: method, dp2, nodedp2, valid-SBM percentage (matches the cells printed per row).\n",
    "print(\"\\\\begin{tabular}{cccc}\")\n",
    "print(\"\\\\toprule\")\n",
    "print(\"\\\\textbf{Method} & \\\\textbf{$\\\\Delta$ DP} & \\\\textbf{$\\\\Delta \\\\text{DP}_{\\\\text{node}}$ } & \\\\% Valid SBM \\\\\\\\\")\n",
    "print(\"\\\\midrule\")\n",
    "for method in ['greedy', 'loss', 'zero', 'uncons']:\n",
    "    row_cells = []\n",
    "    for metric in ['dp2', 'nodedp2']:\n",
    "        # Look up <metric>_<method> and <metric>_<method>_std in the notebook namespace.\n",
    "        mean_val = globals()[metric + '_' + method]\n",
    "        std_val = globals()[metric + '_' + method + '_std']\n",
    "        row_cells.append(f\"{mean_val:.4f} $\\\\pm$ {std_val:.4f}\")\n",
    "    # pct_valid_sbm_* is multiplied by 100 for display.\n",
    "    pct_valid = 100*globals()['pct_valid_sbm_' + method]\n",
    "    print(method.capitalize() + \" & \" + \" & \".join(row_cells) + f\" & {pct_valid:.4f} \\\\\\\\\")\n",
    "print(\"\\\\bottomrule\")\n",
    "print(\"\\\\end{tabular}\")\n",
    "print(\"\\\\end{table}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "digress-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
