{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9d9996c0",
   "metadata": {},
   "source": [
    "### Dataset for Missing Data with Uncertainty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "829a6233",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import random\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "import json\n",
    "from torch.distributions import *\n",
    "import pydgn\n",
    "import math\n",
    "from matplotlib import cm\n",
    "import sys\n",
    "from torch_geometric.utils import *\n",
    "# plt.xkcd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4daf37e",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 0\n",
    "\n",
    "num_features = 1  # how many independent features to generate\n",
    "num_components = 3  # for each feature, the number of components for its own mixture model\n",
    "mean_deviation = 30\n",
    "max_std = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0c48302",
   "metadata": {},
   "source": [
    "### Plot the mixture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f353b8e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# First, set seeds for each feature\n",
    "feature_seed = random.randint(0, sys.maxsize) % 100\n",
    "print(f'Chosen feature seed is {feature_seed}')\n",
    "\n",
    "random.seed(feature_seed)\n",
    "np.random.seed(feature_seed)\n",
    "torch.manual_seed(feature_seed)\n",
    "\n",
    "# mean between -100 and 100\n",
    "mean = torch.rand(num_components, num_features)*(200) - 100\n",
    "max_mean = mean + mean_deviation\n",
    "min_mean = mean - mean_deviation\n",
    "\n",
    "print(f'Chosen feature mean range is {mean}')\n",
    "\n",
    "# for each feature, I generate a different mixture model (hence 1 in the argument of torch.rand)\n",
    "mu, std = torch.rand(num_components, num_features)*(max_mean-min_mean) + min_mean, torch.rand(num_components, num_features)*max_std"
   ]
  },
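  {
   "cell_type": "markdown",
   "id": "f1a2b3c4",
   "metadata": {},
   "source": [
    "The heading above promises a plot, so the next cell sketches the density of the (single) feature's mixture. Note: it assumes uniform mixing weights purely for visualization; the actual per-node weights are drawn later from community-specific Dirichlet distributions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1a2b3c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the mixture density, assuming uniform mixing weights\n",
    "# (visualization-only assumption; real weights are Dirichlet-distributed per community)\n",
    "uniform_weights = torch.ones(num_components) / num_components\n",
    "mix_plot = Categorical(probs=uniform_weights)\n",
    "comp_plot = Normal(loc=mu[:, 0], scale=std[:, 0])  # first (and only) feature\n",
    "mm_plot = MixtureSameFamily(mix_plot, comp_plot)\n",
    "\n",
    "grid = torch.linspace(-150, 150, 1000)\n",
    "plt.plot(grid.numpy(), mm_plot.log_prob(grid).exp().numpy())\n",
    "plt.xlabel('feature value')\n",
    "plt.ylabel('density')\n",
    "plt.show()"
   ]
  },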
  {
   "cell_type": "markdown",
   "id": "1cdd1e99",
   "metadata": {},
   "source": [
    "### Generate Graphs with a fixed number of communities 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56f3872a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import community as community_louvain\n",
    "import networkx as nx\n",
    "from numgraph.distributions import *\n",
    "from numgraph.utils import *\n",
    "from numpy.random import default_rng\n",
    "\n",
    "def _find_between_community_edges(g, partition):\n",
    "\n",
    "    edges = dict()\n",
    "\n",
    "    for (ni, nj) in g.edges():\n",
    "        ci = partition[ni]\n",
    "        cj = partition[nj]\n",
    "\n",
    "        if ci != cj:\n",
    "            try:\n",
    "                edges[(ci, cj)] += [(ni, nj)]\n",
    "            except KeyError:\n",
    "                edges[(ci, cj)] = [(ni, nj)]\n",
    "\n",
    "    return edges\n",
    "\n",
    "def _position_nodes(g, partition, **kwargs):\n",
    "\n",
    "    communities = dict()\n",
    "    for node, community in partition.items():\n",
    "        try:\n",
    "            communities[community] += [node]\n",
    "        except KeyError:\n",
    "            communities[community] = [node]\n",
    "    pos = dict()\n",
    "    for ci, nodes in communities.items():\n",
    "        subgraph = g.subgraph(nodes)\n",
    "        pos_subgraph = nx.spring_layout(subgraph, **kwargs)\n",
    "        pos.update(pos_subgraph)\n",
    "    return pos\n",
    "\n",
    "def _position_communities(g, partition, **kwargs):\n",
    "    # create a weighted graph, in which each node corresponds to a community,\n",
    "    # and each edge weight to the number of edges between communities\n",
    "    between_community_edges = _find_between_community_edges(g, partition)\n",
    "    communities = set(partition.values())\n",
    "    hypergraph = nx.DiGraph()\n",
    "    hypergraph.add_nodes_from(communities)\n",
    "    for (ci, cj), edges in between_community_edges.items():\n",
    "        hypergraph.add_edge(ci, cj, weight=len(edges))\n",
    "    # find layout for communities\n",
    "    pos_communities = nx.spring_layout(hypergraph, **kwargs)\n",
    "    # set node positions to position of community\n",
    "    pos = dict()\n",
    "    for node, community in partition.items():\n",
    "        pos[node] = pos_communities[community]\n",
    "    return pos\n",
    "\n",
    "def community_layout(g, partition):\n",
    "    pos_communities = _position_communities(g, partition, scale=3.)\n",
    "    pos_nodes = _position_nodes(g, partition, scale=1.)\n",
    "    # combine positions\n",
    "    pos = dict()\n",
    "    for node in g.nodes():\n",
    "        pos[node] = pos_communities[node] + pos_nodes[node]\n",
    "    return pos\n",
    "\n",
    "def plot_sbm(G, seed):\n",
    "    partition = community_louvain.best_partition(G, random_state=seed)\n",
    "    pos = community_layout(G, partition)\n",
    "    nx.draw(G, pos, node_color=list(partition.values()), arrowstyle='-|>')\n",
    "    plt.show()\n",
    "    \n",
    "# SBM\n",
    "print('SBM')\n",
    "block_size = [15, 5, 3]\n",
    "probs = [[0.5, 0.01, 0.01], \n",
    "         [0.01, 0.5, 0.01],\n",
    "         [0.01, 0.01, 0.5]]\n",
    "generator = lambda b, p, rng: erdos_renyi_coo(b, p)\n",
    "e, _ = stochastic_block_model_coo(block_size, probs, generator, rng = default_rng(seed))\n",
    "G = nx.from_edgelist(e)\n",
    "\n",
    "plot_sbm(G, seed=seed)"
   ]
  },
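  {
   "cell_type": "markdown",
   "id": "e2b3c4d5",
   "metadata": {},
   "source": [
    "As an optional, illustrative sanity check (not part of the original pipeline): since the intra-block probability (0.5) dominates the inter-block one (0.01), Louvain community detection should roughly recover the three planted blocks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b3c4d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check: Louvain should roughly recover the planted blocks\n",
    "partition = community_louvain.best_partition(G, random_state=seed)\n",
    "print(f'Louvain detected {len(set(partition.values()))} communities (3 planted blocks)')"
   ]
  },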
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79f2fca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.data import Batch, Data\n",
    "from torch_geometric.utils import *\n",
    "from torch_geometric.transforms import RemoveIsolatedNodes\n",
    "if not os.path.exists('GENERATED_DATA/missing_data/'):\n",
    "    os.makedirs('GENERATED_DATA/missing_data/')\n",
    "\n",
    "transform = RemoveIsolatedNodes()\n",
    "\n",
    "# BACKUP\n",
    "# min_size = 50\n",
    "# max_size = 100\n",
    "# inter_max_prob = 0.01\n",
    "# inter_min_prob = 0.002\n",
    "# intra_max_prob = 0.25\n",
    "# intra_min_prob = 0.1\n",
    "min_size = 50\n",
    "max_size = 100\n",
    "inter_max_prob = 0.01\n",
    "inter_min_prob = 0.002\n",
    "intra_max_prob = 0.25\n",
    "intra_min_prob = 0.1\n",
    "\n",
    "num_graphs = 100\n",
    "per_graph_num_samples = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56fb88e7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "num_communities = 5\n",
    "assert num_components == 3\n",
    "\n",
    "per_community_dirichlet = [\n",
    "    Dirichlet(torch.tensor([9., 1., 1.])),\n",
    "    Dirichlet(torch.tensor([1., 9., 1.])),\n",
    "    Dirichlet(torch.tensor([1., 1., 9.])),\n",
    "    Dirichlet(torch.tensor([2., 2., 1.])),\n",
    "    Dirichlet(torch.tensor([1., 1., 2.])),\n",
    "]\n",
    "assert len(per_community_dirichlet) == num_communities    \n",
    "\n",
    "plot = False\n",
    "\n",
    "dataset = []\n",
    "for graph in range(num_graphs):\n",
    "    if (graph+1)%100 == 0:\n",
    "        print(f'Processed graph {graph+1}')\n",
    "            \n",
    "    block_size = ((torch.rand(num_communities)*(max_size-min_size) + min_size)/num_communities).int()\n",
    "    community_assignment = torch.tensor([i for i in range(num_communities) for v in range(block_size[i])]).long()\n",
    "    \n",
    "    intra_probs = torch.rand(num_communities, num_communities)*(inter_max_prob-inter_min_prob) + inter_min_prob\n",
    "    intra_probs.fill_diagonal_(0.)\n",
    "    inter_probs = torch.eye(num_communities)*(torch.rand(num_communities, num_communities)*(intra_max_prob-intra_min_prob) + intra_min_prob)    \n",
    "    probs = intra_probs + inter_probs\n",
    "                       \n",
    "    generator = lambda b, p, rng: erdos_renyi_coo(b, p)\n",
    "    e, _ = stochastic_block_model_coo(block_size.tolist(), probs.tolist(), generator, directed=False)  \n",
    "    G = nx.from_edgelist(e)\n",
    "    graph_data = from_networkx(G)\n",
    "    \n",
    "    if plot:\n",
    "        plt.figure()\n",
    "        plot_sbm(G, seed=seed)\n",
    "\n",
    "    assert is_undirected(graph_data.edge_index)\n",
    "    \n",
    "    mixing_weights = []\n",
    "    for c in range(num_communities):\n",
    "        mixing_weights_per_community = per_community_dirichlet[c].sample((block_size[c],))        \n",
    "        mixing_weights.append(mixing_weights_per_community)\n",
    "\n",
    "    mixing_weights = torch.cat(mixing_weights, dim=0)\n",
    "\n",
    "    \n",
    "    if plot:\n",
    "        print(homophily(graph_data.edge_index, community_assignment))\n",
    "        fig = plt.figure()\n",
    "        ax = fig.add_subplot(projection='3d')\n",
    "        ax.scatter(mixing_weights[:,0], mixing_weights[:,1], mixing_weights[:,2])\n",
    "        ax.view_init(30, 30) \n",
    "    \n",
    "    graph_data = Data(x=mixing_weights, edge_index=graph_data.edge_index, y=community_assignment)\n",
    "    \n",
    "    # Remove isolated nodes\n",
    "    graph_data = transform(graph_data)\n",
    "    assert not torch.any(degree(graph_data.edge_index[1]) == 0)\n",
    "\n",
    "    dataset.append(graph_data)\n",
    "    if (graph+1)%num_graphs == 0:\n",
    "        torch.save(dataset, f'GENERATED_DATA/missing_data/data_list_{graph+1}_step1.pt')\n",
    "        dataset = []"
   ]
  },
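  {
   "cell_type": "markdown",
   "id": "d3c4d5e6",
   "metadata": {},
   "source": [
    "Optional sanity check on the graphs just saved: intra-community probabilities are much larger than inter-community ones, so edge homophily of the community labels should be high. This assumes the default chunking, i.e. the last saved file is `data_list_{num_graphs}_step1.pt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3c4d5e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# High intra-community connectivity should yield high label homophily\n",
    "check = torch.load(f'GENERATED_DATA/missing_data/data_list_{num_graphs}_step1.pt')\n",
    "print(np.mean([homophily(g.edge_index, g.y) for g in check[:10]]))"
   ]
  },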
  {
   "cell_type": "markdown",
   "id": "cafd70c1",
   "metadata": {},
   "source": [
    "### Load the dataset and perform one step of neighboring aggregation\n",
    "#### Linearly Combine the node feature and the neighborhood contribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "104e3ed8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "from torch_scatter import scatter_min, scatter_max, scatter_mean, scatter_sum, scatter_std\n",
    "\n",
    "single_node_weight = 0\n",
    "struc_weight = 100\n",
    "assert single_node_weight+struc_weight == 100\n",
    "\n",
    "dataset_name = f'Synthetic_{single_node_weight}_{struc_weight}'\n",
    "\n",
    "print((single_node_weight/100), (struc_weight/100))\n",
    "\n",
    "plot = False\n",
    "\n",
    "# for i in range(num_graphs//100):\n",
    "#     index = (i+1)*100\n",
    "#     print(index)\n",
    "    \n",
    "#     dataset = torch.load(f'GENERATED_DATA/missing_data/data_list_{index}_step1.pt')\n",
    "    \n",
    "#     new_dataset = []\n",
    "#     num_samples = 0\n",
    "#     for g in dataset:\n",
    "        \n",
    "#         # aggregate cluster assignments for each feature --> distribution of cluster according to neighbors\n",
    "#         x_mean_aggr = scatter_mean(g.x[g.edge_index[0],:], g.edge_index[1], dim=0, out=torch.zeros_like(g.x[:,:]))\n",
    "#         structure_dependent_mixing_weights = x_mean_aggr\n",
    "#         weights = (single_node_weight/100)*g.x + (struc_weight/100)*structure_dependent_mixing_weights\n",
    "\n",
    "#         mix = Categorical(probs=weights)          \n",
    "#         comp = Independent(Normal(loc=mu.unsqueeze(0).repeat(g.x.shape[0], 1, 1), \n",
    "#                                   scale=std.unsqueeze(0).repeat(g.x.shape[0], 1, 1)),\n",
    "#                            1)\n",
    "#         mm = MixtureSameFamily(mix, comp)\n",
    "\n",
    "# #         print(mu.shape, std.shape, weights.shape)\n",
    "\n",
    "\n",
    "#         if plot:\n",
    "#             fig = plt.figure()\n",
    "#             ax = fig.add_subplot(1, 2, 1, projection='3d')              \n",
    "#             ax.scatter(g.x[:,0], g.x[:,1], g.x[:,2])\n",
    "#             ax.view_init(30, 30) \n",
    "\n",
    "#             ax = fig.add_subplot(1, 2, 2, projection='3d')              \n",
    "#             ax.scatter(weights[:,0], weights[:,1], weights[:,2])\n",
    "#             ax.view_init(30, 30) \n",
    "#             plt.show()\n",
    "\n",
    "#         for _ in range(per_graph_num_samples):\n",
    "\n",
    "#             if (num_samples+1)%1000 == 0:\n",
    "#                 print(num_samples)\n",
    "\n",
    "#             sample = mm.sample()\n",
    "            \n",
    "# #             print(sample.shape)\n",
    "# #             \n",
    "#             g_sample = Data(x=sample, edge_index=g.edge_index.clone(), y=g.y.clone())\n",
    "        \n",
    "\n",
    "#             new_dataset.append(g_sample)\n",
    "#             num_samples += 1\n",
    "\n",
    "#     if not os.path.exists(f'GENERATED_DATA/missing_data/{dataset_name}'):\n",
    "#         os.makedirs(f'GENERATED_DATA/missing_data/{dataset_name}')\n",
    "        \n",
    "#     torch.save(new_dataset, f'GENERATED_DATA/missing_data/{dataset_name}/data_list_{index}.pt')"
   ]
  },
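  {
   "cell_type": "markdown",
   "id": "c4d5e6f7",
   "metadata": {},
   "source": [
    "To make the linear combination above concrete, here is a tiny toy example on hypothetical data (not part of the generated dataset): a 3-node undirected path, where each node's final mixing weights blend its own weights with the mean of its neighbors' using the split configured above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d5e6f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of the one-step neighborhood aggregation used above\n",
    "toy_x = torch.tensor([[1., 0., 0.],\n",
    "                      [0., 1., 0.],\n",
    "                      [0., 0., 1.]])\n",
    "# undirected path 0-1-2 in COO format\n",
    "toy_edge_index = torch.tensor([[0, 1, 1, 2],\n",
    "                               [1, 0, 2, 1]])\n",
    "toy_aggr = scatter_mean(toy_x[toy_edge_index[0], :], toy_edge_index[1], dim=0,\n",
    "                        out=torch.zeros_like(toy_x))\n",
    "toy_weights = (single_node_weight/100)*toy_x + (struc_weight/100)*toy_aggr\n",
    "print(toy_weights)  # node 1 mixes the weights of nodes 0 and 2"
   ]
  },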
  {
   "cell_type": "markdown",
   "id": "76e85539",
   "metadata": {},
   "source": [
    "### Now compute dataset statistics for each community, to see how the node distribution changed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f84e1bff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import *\n",
    "all_x_pre = []\n",
    "all_y_pre = []\n",
    "all_degree = []\n",
    "for i in range(num_graphs//100):\n",
    "    index = (i+1)*num_graphs\n",
    "    dataset = torch.load(f'GENERATED_DATA/missing_data/data_list_{index}_step1.pt')\n",
    "    \n",
    "    for g in dataset:\n",
    "        all_x_pre.append(g.x)\n",
    "        all_y_pre.append(g.y)\n",
    "        all_degree.append(degree(g.edge_index[1], num_nodes=g.x.shape[0]))\n",
    "        \n",
    "all_x_pre = torch.cat(all_x_pre, dim=0)\n",
    "all_y_pre = torch.cat(all_y_pre, dim=0)\n",
    "all_degree = torch.cat(all_degree, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e9ad825",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.utils import *\n",
    "all_num_nodes = []\n",
    "all_num_edges = []\n",
    "for i in range(num_graphs//100):\n",
    "    index = (i+1)*num_graphs\n",
    "    dataset = torch.load(f'GENERATED_DATA/missing_data/{dataset_name}/data_list_{index}.pt')\n",
    "    \n",
    "    for g in dataset:\n",
    "        all_num_nodes.append(g.x.shape[0])\n",
    "        all_num_edges.append(g.edge_index.shape[1])\n",
    "        \n",
    "print(np.mean(all_num_nodes), np.mean(all_num_edges), len(all_num_nodes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c2fd2d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(all_degree.numpy(), bins=40)\n",
    "print(torch.unique(all_degree, return_counts=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da36ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_x = []\n",
    "all_y = []\n",
    "print(dataset_name)\n",
    "for i in range(num_graphs//100):\n",
    "    index = (i+1)*num_graphs\n",
    "    dataset = torch.load(f'GENERATED_DATA/missing_data/{dataset_name}/data_list_{index}.pt')\n",
    "    \n",
    "    for g in dataset:\n",
    "        all_x.append(g.x)\n",
    "        all_y.append(g.y)\n",
    "\n",
    "all_x = torch.cat(all_x, dim=0)\n",
    "all_y = torch.cat(all_y, dim=0)"
   ]
  },
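  {
   "cell_type": "markdown",
   "id": "b5e6f7a8",
   "metadata": {},
   "source": [
    "The section heading mentions per-community statistics, but none are shown explicitly; the sketch below compares, for each community, the mean mixing weights before sampling with the mean of the sampled features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5e6f7a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-community statistics: mean mixing weights (pre) vs mean sampled features (post)\n",
    "for c in range(num_communities):\n",
    "    pre_mean = all_x_pre[all_y_pre == c].mean(dim=0)\n",
    "    post_mean = all_x[all_y == c].mean(dim=0)\n",
    "    print(f'community {c}: mixing weights {pre_mean}, sampled features {post_mean}')"
   ]
  },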
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a720a0d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data = all_x.numpy()\n",
    "#plt.scatter(x=data[:10000, 10], y=data[:10000, 4], s=40, cmap='viridis')\n",
    "#sns.kdeplot(x=data[:10000, 0], y=data[:10000, 4], s=40, cmap='viridis')\n",
    "plt.hist(data[:20000, 0], bins=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7617845",
   "metadata": {},
   "source": [
    "## This code evaluates the entire network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee3c5fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os.path as osp\n",
    "from torch_geometric.data import Batch\n",
    "\n",
    "from pydgn.evaluation.config import Config\n",
    "from pydgn.experiment.util import s2c\n",
    "\n",
    "outer_k = 0\n",
    "ASSESSMENT_FOLDER = f'GSPN_RESULTS/UNSUPERVISED/Synthetic_{per_comm_weight}_{struc_weight}/synthetic_gaussian_nomask_2layer_SyntheticDataset/MODEL_ASSESSMENT'\n",
    "OUTER_FOLD_BASE = 'OUTER_FOLD_'\n",
    "SELECTION_FOLDER = 'MODEL_SELECTION'\n",
    "WINNER_CONFIG = 'winner_config.json'\n",
    "\n",
    "outer_folder = osp.join(ASSESSMENT_FOLDER, OUTER_FOLD_BASE + str(outer_k + 1))\n",
    "config_fname = osp.join(outer_folder, SELECTION_FOLDER, WINNER_CONFIG)\n",
    "\n",
    "dataset_name = f'SyntheticDataset_{per_comm_weight}_{struc_weight}'\n",
    "splits_filepath = f'DATA_SPLITS/{dataset_name}/SyntheticDataset_outer1_inner1.splits'\n",
    "outer_folds = 1\n",
    "inner_folds = 1\n",
    "\n",
    "with open(config_fname, 'r') as f:\n",
    "    best_config = json.load(f)\n",
    "\n",
    "config_with_metadata = Config(best_config['config'])\n",
    "\n",
    "dataset_getter_class = s2c(config_with_metadata.dataset_getter)\n",
    "dataset_getter = dataset_getter_class(config_with_metadata.data_root,\n",
    "                                      splits_filepath,\n",
    "                                      s2c(config_with_metadata.dataset_class),\n",
    "                                      dataset_name,\n",
    "                                      s2c(config_with_metadata.data_loader),\n",
    "                                      config_with_metadata.data_loader_args,\n",
    "                                      outer_folds,\n",
    "                                      inner_folds)\n",
    "\n",
    "dataset_getter.set_inner_k(0)\n",
    "dataset_getter.set_outer_k(0)\n",
    "\n",
    "# not really used\n",
    "dataset_getter.set_exp_seed(0)\n",
    "\n",
    "batch_size = 32\n",
    "shuffle = False\n",
    "\n",
    "# Instantiate the Dataset\n",
    "train_loader = dataset_getter.get_outer_train(batch_size=batch_size, shuffle=shuffle)\n",
    "val_loader = dataset_getter.get_outer_val(batch_size=batch_size, shuffle=shuffle)\n",
    "test_loader = dataset_getter.get_outer_test(batch_size=batch_size, shuffle=shuffle)\n",
    "\n",
    "# Call this after the loaders: the datasets may need to be instantiated with additional parameters\n",
    "dim_node_features = dataset_getter.get_dim_node_features()\n",
    "dim_edge_features = dataset_getter.get_dim_edge_features()\n",
    "dim_target = dataset_getter.get_dim_target()"
   ]
  },
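  {
   "cell_type": "markdown",
   "id": "a6f7a8b9",
   "metadata": {},
   "source": [
    "For reference, the next cell only prints the hyperparameter keys of the winning configuration loaded above (purely informative; it relies on the `best_config` dict already in memory)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6f7a8c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the winning configuration selected during model selection\n",
    "print(list(best_config['config'].keys()))"
   ]
  },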
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f741ad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from model_log import GSPN\n",
    "device = 'cuda:3'\n",
    "ckpt = torch.load(f'GSPN_RESULTS/UNSUPERVISED/Synthetic_{per_comm_weight}_{struc_weight}/synthetic_gaussian_nomask_2layer_SyntheticDataset/MODEL_ASSESSMENT/OUTER_FOLD_1/final_run1/best_checkpoint.pth', map_location='cpu')\n",
    "\n",
    "model = GSPN(dim_node_features, 0, dim_node_features, None, config_with_metadata['supervised_config'])\n",
    "\n",
    "model_state = ckpt['model_state']\n",
    "model.load_state_dict(ckpt['model_state'])\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aceda4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "masked_nodes = []\n",
    "non_masked_nodes = []\n",
    "x = []\n",
    "x_imputed = []\n",
    "edge_index = []\n",
    "perc_masked_features = []\n",
    "log_lik = []\n",
    "\n",
    "curr_node_id = 0\n",
    "\n",
    "for batch in test_loader:\n",
    "    # Move data to device\n",
    "    batch.to(device)\n",
    "\n",
    "    output, embs, extra = model.forward(batch)\n",
    "    \n",
    "    edge_index.append(batch.edge_index.to('cpu') + curr_node_id)\n",
    "    curr_node_id += batch.num_nodes\n",
    "    \n",
    "    perc_masked_features.append(batch.perc_masked_features.to('cpu'))\n",
    "    \n",
    "    # Move output to cpu\n",
    "    batch.to('cpu')\n",
    "    embs.to('cpu')\n",
    "    for t in extra:\n",
    "        if t is not None:\n",
    "            t.to('cpu')\n",
    "    if output is not None:\n",
    "        output.to('cpu')\n",
    "    \n",
    "    log_lik_batch, _, _, x_batch, x_imputed_batch, masked_nodes_batch, _, _, _, _ = extra\n",
    "    \n",
    "    log_lik.append(log_lik_batch.to('cpu'))\n",
    "    x.append(x_batch.to('cpu'))\n",
    "    x_imputed.append(x_imputed_batch.to('cpu'))\n",
    "    masked_nodes.append(masked_nodes_batch.to('cpu'))\n",
    "\n",
    "log_lik = torch.cat(log_lik, dim=0)\n",
    "x = torch.cat(x, dim=0)\n",
    "x_imputed = torch.cat(x_imputed, dim=0)\n",
    "masked_nodes = torch.cat(masked_nodes, dim=0)\n",
    "edge_index = torch.cat(edge_index, dim=1)\n",
    "perc_masked_features = torch.cat(perc_masked_features, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e88208f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b25870b",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Using the model's prediction, compute MSE for missing features and bin the results according to percentage of masked features for each node"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67387541",
   "metadata": {},
   "outputs": [],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d137aa7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = x.clone()\n",
    "x_imputed = x_imputed.clone()\n",
    "x_mvi = x.clone()\n",
    "degree_vec = degree(edge_index[1], num_nodes=x.shape[0])\n",
    "\n",
    "# Ensure MSE is 0 for\n",
    "x[non_masked_nodes] = 0.\n",
    "x_imputed[non_masked_nodes] = 0.\n",
    "\n",
    "x_mvi[masked_nodes] = torch.nan\n",
    "mv = torch.nanmean(x_mvi, dim=0, keepdim=True).repeat(x.shape[0], 1)\n",
    "x_mvi[masked_nodes] = mv[masked_nodes]\n",
    "\n",
    "\n",
    "# mean value imputation\n",
    "mse_per_vertex_mvi = torch.nn.functional.mse_loss(x, x_mvi, reduction='none').mean(dim=1)\n",
    "\n",
    "\n",
    "mse_per_vertex = torch.nn.functional.mse_loss(x, x_imputed, reduction='none').mean(dim=1)\n",
    "mse_per_vertex.shape\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f61818ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "xs = []\n",
    "ys = []\n",
    "ys_mvi = []\n",
    "for i in range(1, 21):\n",
    "    min_perc = (i-1)/20\n",
    "    max_perc = i/20\n",
    "    bin_mask = torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc))  \n",
    "    \n",
    "    bin_values = log_lik[bin_mask].mean().detach().numpy()\n",
    "    bin_values_mvi = log_lik[bin_mask].mean().log().detach().numpy()\n",
    "    \n",
    "    #bin_values = mse_per_vertex[bin_mask].mean().detach().numpy()\n",
    "    #bin_values_mvi = mse_per_vertex_mvi[bin_mask].mean().log().detach().numpy()\n",
    "    \n",
    "    xs.append(i)\n",
    "    ys.append(bin_values)\n",
    "    ys_mvi.append(bin_values_mvi)\n",
    "    \n",
    "xs = torch.tensor(xs).numpy()\n",
    "\n",
    "unique_degrees = torch.unique(degree_vec).int().numpy().tolist()\n",
    "\n",
    "heatmap = torch.zeros(len(unique_degrees), 20).numpy()\n",
    "for deg_id in range(len(unique_degrees)):\n",
    "    for i in range(1, 21):\n",
    "        min_perc = (i-1)/20\n",
    "        max_perc = i/20\n",
    "        bin_mask = torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc))  \n",
    "\n",
    "        degree_mask = (degree_vec == unique_degrees[deg_id]) \n",
    "        bin_mask = torch.logical_and(bin_mask, degree_mask)\n",
    "        \n",
    "        heatmap[deg_id, i-1] = log_lik[bin_mask].mean().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7988cd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "log_lik.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58afeef1",
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = sns.heatmap(heatmap)\n",
    "ax.invert_yaxis()\n",
    "plt.yticks(np.arange(len(unique_degrees)))\n",
    "ax.set_yticklabels(unique_degrees)\n",
    "\n",
    "x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]\n",
    "plt.xticks(np.arange(1, 21))\n",
    "ax.set_xticklabels(x_ticks_labels, rotation=-30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad3104c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from model_log import GSPN\n",
    "device = 'cuda:3'\n",
    "ckpt = torch.load(f'GSPN_RESULTS/UNSUPERVISED/Synthetic_{per_comm_weight}_{struc_weight}/synthetic_gaussian_nomask_SyntheticDataset/MODEL_ASSESSMENT/OUTER_FOLD_1/final_run1/best_checkpoint.pth', map_location='cpu')\n",
    "\n",
    "outer_k = 0\n",
    "ASSESSMENT_FOLDER = f'GSPN_RESULTS/UNSUPERVISED/Synthetic_{per_comm_weight}_{struc_weight}/synthetic_gaussian_nomask_SyntheticDataset/MODEL_ASSESSMENT'\n",
    "\n",
    "outer_folder = osp.join(ASSESSMENT_FOLDER, OUTER_FOLD_BASE + str(outer_k + 1))\n",
    "config_fname = osp.join(outer_folder, SELECTION_FOLDER, WINNER_CONFIG)\n",
    "\n",
    "with open(config_fname, 'r') as f:\n",
    "    best_config = json.load(f)\n",
    "\n",
    "config_with_metadata = Config(best_config['config'])\n",
    "\n",
    "model = GSPN(dim_node_features, 0, dim_node_features, None, config_with_metadata['supervised_config'])\n",
    "\n",
    "model_state = ckpt['model_state']\n",
    "model.load_state_dict(ckpt['model_state'])\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95e9a2ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "masked_nodes = []\n",
    "non_masked_nodes = []\n",
    "x = []\n",
    "x_imputed = []\n",
    "edge_index = []\n",
    "perc_masked_features = []\n",
    "log_lik_gmm = []\n",
    "\n",
    "curr_node_id = 0\n",
    "\n",
    "for batch in test_loader:\n",
    "    # Move data to device\n",
    "    batch.to(device)\n",
    "\n",
    "    output, embs, extra = model.forward(batch)\n",
    "    \n",
    "    edge_index.append(batch.edge_index.to('cpu') + curr_node_id)\n",
    "    curr_node_id += batch.num_nodes\n",
    "    \n",
    "    perc_masked_features.append(batch.perc_masked_features.to('cpu'))\n",
    "    \n",
    "    # Move output to cpu\n",
    "    batch.to('cpu')\n",
    "    embs.to('cpu')\n",
    "    for t in extra:\n",
    "        if t is not None:\n",
    "            t.to('cpu')\n",
    "    if output is not None:\n",
    "        output.to('cpu')\n",
    "    \n",
    "    log_lik_batch, _, _, x_batch, x_imputed_batch, masked_nodes_batch, _, _, _, _ = extra\n",
    "    \n",
    "    log_lik_gmm.append(log_lik_batch.to('cpu'))\n",
    "    x.append(x_batch.to('cpu'))\n",
    "    x_imputed.append(x_imputed_batch.to('cpu'))\n",
    "    masked_nodes.append(masked_nodes_batch.to('cpu'))\n",
    "\n",
    "log_lik_gmm = torch.cat(log_lik_gmm, dim=0)\n",
    "x = torch.cat(x, dim=0)\n",
    "x_imputed = torch.cat(x_imputed, dim=0)\n",
    "masked_nodes = torch.cat(masked_nodes, dim=0)\n",
    "edge_index = torch.cat(edge_index, dim=1)\n",
    "perc_masked_features = torch.cat(perc_masked_features, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15baae63",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c79705d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = x.clone()\n",
    "x_imputed = x_imputed.clone()\n",
    "x_mvi = x.clone()\n",
    "\n",
    "# Ensure MSE is 0 for\n",
    "x[non_masked_nodes] = 0.\n",
    "x_imputed[non_masked_nodes] = 0.\n",
    "\n",
    "mse_per_vertex_gmm = torch.nn.functional.mse_loss(x, x_imputed, reduction='none').mean(dim=1)\n",
    "mse_per_vertex_gmm.shape\n",
    "\n",
    "xs = []\n",
    "ys_gmm = []\n",
    "\n",
    "\n",
    "for i in range(1, 21):\n",
    "    min_perc = (i-1)/20\n",
    "    max_perc = i/20\n",
    "    bin_mask = torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc))  \n",
    "    \n",
    "    bin_values_gmm = log_lik_gmm[bin_mask].mean().detach().numpy()\n",
    "    #bin_values_gmm = mse_per_vertex_gmm[bin_mask].mean().detach().numpy()\n",
    "    \n",
    "    xs.append(i)\n",
    "    ys_gmm.append(bin_values_gmm)\n",
    "\n",
    "xs = torch.tensor(xs).numpy()\n",
    "\n",
    "unique_degrees = torch.unique(degree_vec).int().numpy().tolist()\n",
    "\n",
    "heatmap_gmm = torch.zeros(len(unique_degrees), 20).numpy()\n",
    "for deg_id in range(len(unique_degrees)):\n",
    "    for i in range(1, 21):\n",
    "        min_perc = (i-1)/20\n",
    "        max_perc = i/20\n",
    "        bin_mask = torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc))  \n",
    "\n",
    "        degree_mask = (degree_vec == unique_degrees[deg_id]) \n",
    "        bin_mask = torch.logical_and(bin_mask, degree_mask)\n",
    "        \n",
    "        heatmap_gmm[deg_id, i-1] = log_lik_gmm[bin_mask].mean().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59845317",
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = sns.heatmap(heatmap_gmm)\n",
    "ax.invert_yaxis()\n",
    "plt.yticks(np.arange(len(unique_degrees)))\n",
    "ax.set_yticklabels(unique_degrees)\n",
    "\n",
    "x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]\n",
    "plt.xticks(np.arange(1, 21))\n",
    "ax.set_xticklabels(x_ticks_labels, rotation=-30)\n",
    "\n",
    "plt.figure()\n",
    "ax = sns.heatmap(heatmap-heatmap_gmm)\n",
    "ax.invert_yaxis()\n",
    "plt.yticks(np.arange(len(unique_degrees)))\n",
    "ax.set_yticklabels(unique_degrees)\n",
    "\n",
    "x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]\n",
    "plt.xticks(np.arange(1, 21))\n",
    "ax.set_xticklabels(x_ticks_labels, rotation=-30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e088ab1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from baseline_mask import *\n",
    "model = MeanAggregation(dim_node_features, 0, dim_node_features, None, {})\n",
    "model.to(device)\n",
    "\n",
    "masked_nodes = []\n",
    "non_masked_nodes = []\n",
    "x = []\n",
    "x_imputed = []\n",
    "edge_index = []\n",
    "perc_masked_features = []\n",
    "\n",
    "curr_node_id = 0\n",
    "\n",
    "for batch in test_loader:\n",
    "    # Move data to device\n",
    "    batch.to(device)\n",
    "\n",
    "    output, embs, extra = model.forward(batch)\n",
    "    \n",
    "    edge_index.append(batch.edge_index.to('cpu') + curr_node_id)\n",
    "    curr_node_id += batch.num_nodes\n",
    "    \n",
    "    perc_masked_features.append(batch.perc_masked_features.to('cpu'))\n",
    "    \n",
    "    # Move output to cpu\n",
    "    batch.to('cpu')\n",
    "    embs.to('cpu')\n",
    "    for t in extra:\n",
    "        if t is not None:\n",
    "            t.to('cpu')\n",
    "    if output is not None:\n",
    "        output.to('cpu')\n",
    "    \n",
    "    _, _, _, x_batch, x_imputed_batch, masked_nodes_batch, _, _, _, _ = extra\n",
    "    \n",
    "    x.append(x_batch.to('cpu'))\n",
    "    x_imputed.append(x_imputed_batch.to('cpu'))\n",
    "    masked_nodes.append(masked_nodes_batch.to('cpu'))\n",
    "\n",
    "x = torch.cat(x, dim=0)\n",
    "x_imputed = torch.cat(x_imputed, dim=0)\n",
    "masked_nodes = torch.cat(masked_nodes, dim=0)\n",
    "edge_index = torch.cat(edge_index, dim=1)\n",
    "perc_masked_features = torch.cat(perc_masked_features, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6c918a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.to('cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64e8761e",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = x.clone()\n",
    "x_imputed = x_imputed.clone()\n",
    "x_mvi = x.clone()\n",
    "\n",
    "# Ensure MSE is 0 for\n",
    "x[non_masked_nodes] = 0.\n",
    "x_imputed[non_masked_nodes] = 0.\n",
    "\n",
    "mse_per_vertex_baseline = torch.nn.functional.mse_loss(x, x_imputed, reduction='none').mean(dim=1)\n",
    "mse_per_vertex_baseline.shape\n",
    "\n",
    "xs = []\n",
    "ys_baseline = []\n",
    "\n",
    "\n",
    "for i in range(1, 21):\n",
    "    min_perc = (i-1)/20\n",
    "    max_perc = i/20\n",
    "    bin_mask = torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc))  \n",
    "    \n",
    "    bin_values_baseline = mse_per_vertex_baseline[bin_mask].mean().detach().numpy()\n",
    "    \n",
    "    xs.append(i)\n",
    "    ys_baseline.append(bin_values_baseline)\n",
    "\n",
    "xs = torch.tensor(xs).numpy()\n",
    "\n",
    "unique_degrees = torch.unique(degree_vec).int().numpy().tolist()\n",
    "\n",
    "mse_heatmap_baseline = torch.zeros(len(unique_degrees), 20).numpy()\n",
    "for deg_id in range(len(unique_degrees)):\n",
    "    for i in range(1, 21):\n",
    "        min_perc = (i-1)/20\n",
    "        max_perc = i/20\n",
    "        bin_mask = torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc))  \n",
    "\n",
    "        degree_mask = (degree_vec == unique_degrees[deg_id]) \n",
    "        bin_mask = torch.logical_and(bin_mask, degree_mask)\n",
    "        \n",
    "        mse_heatmap_baseline[deg_id, i-1] = mse_per_vertex_baseline[bin_mask].mean().detach().numpy()\n",
    "\n",
    "plt.scatter(x_imputed[:1000, 0], x_imputed[:1000, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2344e2bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "neigh_aggr_score = torch.nn.functional.mse_loss(x, x_imputed, reduction='mean').mean(dim=0)\n",
    "print(f'DGN score is {neigh_aggr_score}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c8e5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "ax = sns.heatmap(mse_heatmap_baseline)\n",
    "ax.invert_yaxis()\n",
    "plt.yticks(np.arange(len(unique_degrees)))\n",
    "ax.set_yticklabels(unique_degrees)\n",
    "\n",
    "x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]\n",
    "plt.xticks(np.arange(1, 21))\n",
    "ax.set_xticklabels(x_ticks_labels, rotation=-30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a29730c",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.style.use('seaborn-colorblind')\n",
    "\n",
    "fig, ax = plt.subplots(1,1) \n",
    "plt.bar(xs-0.2, ys_gmm, width=0.2, label='GMM', fill=False, hatch='')\n",
    "x_ticks_labels = [f'{(i-1)*5}%-{(i)*5}%' for i in range(1, 21)]\n",
    "plt.xticks(np.arange(1, 21))\n",
    "ax.set_xticklabels(x_ticks_labels, rotation=-90)\n",
    "plt.ylabel('Missing features MSE')\n",
    "plt.xlabel('Percentage of masked features per node')\n",
    "plt.bar(xs, ys, width=0.2, label='BGC', fill=True, hatch='///')\n",
    "#plt.bar(xs+0.2, ys_baseline, width=0.2, label='Baseline')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.savefig('perf_vs_percentage_masking.png', dpi=350)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9140f20d",
   "metadata": {},
   "outputs": [],
   "source": [
    "ys_gmm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c7cf62f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e3daa24",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(1, 21):\n",
    "    min_perc = (i-1)/20\n",
    "    max_perc = i/20\n",
    "    bin_mask = torch.logical_and((perc_masked_features >= min_perc), (perc_masked_features <= max_perc))\n",
    "    print(bin_mask.sum())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
