{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the hypergraph\n",
    "import yaml\n",
    "import logging\n",
    "import itertools\n",
    "import os\n",
    "import sys\n",
    "import random\n",
    "\n",
    "from scipy import sparse\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import networkx as nx\n",
    "import community\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv, SAGEConv\n",
    "from torch_geometric.data import Data\n",
    "\n",
    "from cell.utils import link_prediction_performance\n",
    "from cell.cell import Cell, EdgeOverlapCriterion, LinkPredictionCriterion\n",
    "from cell.graph_statistics import compute_graph_statistics\n",
    "\n",
    "from utils import load_graphs\n",
    "from cliques import compute_cliques"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNN(torch.nn.Module):\n",
    "    def __init__(self, node_features):\n",
    "        super().__init__()\n",
    "        # GCN initialization\n",
    "        self.conv1 = SAGEConv(node_features, 128)\n",
    "        self.conv2 = SAGEConv(128, 128)\n",
    "        self.bn = torch.nn.BatchNorm1d(128)\n",
    "        \n",
    "        # self.conv2 = GCNConv(128, 128)\n",
    "\n",
    "    def forward(self, data):\n",
    "        x, edge_index = data.x, data.edge_index\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = self.bn(x)\n",
    "        x = self.conv2(x, edge_index)\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "def save_hypergraph(hg, path):\n",
    "    with open(path, 'w') as f:\n",
    "        for edge in hg:\n",
    "            f.write(' '.join(map(str,edge)) + '\\n')\n",
    "\n",
    "\n",
    "def hypergraph_metrics(hg):\n",
    "    # original hypergraph\n",
    "    num_edges = len(hg)\n",
    "    nodes = set()\n",
    "    node_degrees = {}\n",
    "    for edge in hg:\n",
    "        for node in edge:\n",
    "            nodes.add(node)\n",
    "            node_degrees[node] = node_degrees.get(node, 0) + 1\n",
    "    num_nodes = len(nodes)\n",
    "    \n",
    "    # density\n",
    "    density = num_edges / num_nodes\n",
    "\n",
    "    # Average size\n",
    "    avg_size = sum(len(edge) for edge in hg) / num_edges\n",
    "\n",
    "    # Average degree\n",
    "    avg_degree = sum(node_degrees.values()) / num_nodes\n",
    "\n",
    "\n",
    "    # projected graph\n",
    "    G = nx.Graph()\n",
    "    # Add all nodes from the hypergraph\n",
    "    nodes = set(node for edge in hg for node in edge)\n",
    "    G.add_nodes_from(nodes)\n",
    "    # For each hyperedge, create a clique\n",
    "    for edge in hg:\n",
    "        # Add edges between all pairs of nodes in the hyperedge\n",
    "        G.add_edges_from(itertools.combinations(edge, 2))\n",
    "    \n",
    "    part_G = community.best_partition(G)\n",
    "    mod_G = community.modularity(part_G, G)\n",
    "\n",
    "\n",
    "    # bipartite graph\n",
    "    B = nx.Graph()\n",
    "    # Add nodes for the original vertices (left set)\n",
    "    left_nodes = set(node for edge in hg for node in edge)\n",
    "    B.add_nodes_from(left_nodes, bipartite=0)\n",
    "    # Add nodes for the hyperedges (right set)\n",
    "    right_nodes = [f'e{i}' for i in range(len(hg))]\n",
    "    B.add_nodes_from(right_nodes, bipartite=1)\n",
    "    # Add edges between vertices and their corresponding hyperedges\n",
    "    for i, edge in enumerate(hg):\n",
    "        for node in edge:\n",
    "            B.add_edge(node, f'e{i}')\n",
    "\n",
    "\n",
    "    part_B = community.best_partition(B)\n",
    "    mod_B = community.modularity(part_B, B)\n",
    "\n",
    "    return {\n",
    "        \"density\": density,\n",
    "        \"average_size\": avg_size,\n",
    "        \"average_degree\": avg_degree,\n",
    "        \"coefficient\": nx.average_clustering(G),\n",
    "        \"G_modularity\": mod_G,\n",
    "        \"B_modularity\": mod_B\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.INFO)\n",
    "\n",
    "config  = yaml.safe_load(open('./config.yml'))\n",
    "config['dataset'] = 'NDC-classes'\n",
    "graphs = load_graphs(config, logger)\n",
    "config['beta'] = len(graphs['simplicies_train']) * 10\n",
    "\n",
    "cliques = compute_cliques(graphs, config, logger)\n",
    "\n",
    "# data = np.array([len(s) for s in graphs['simplicies_train']])\n",
    "# hist, bins = np.histogram(data, bins=np.linspace(1, 8, 8))\n",
    "# sns.displot(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_geometric.nn import Node2Vec\n",
    "\n",
    "graph_adjacency_matrix, weighted_graph_adjacency_matrix = nx.to_numpy_array(graphs['G_train'], nodelist=sorted(graphs['G_train'].nodes())), nx.to_numpy_array(graphs['G_weighted'], nodelist=sorted(graphs['G_train'].nodes()))\n",
    "\n",
    "edge_index = torch.tensor(np.array(graph_adjacency_matrix.nonzero()), dtype=torch.long)\n",
    "data = Data(edge_index=edge_index)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "model = Node2Vec(\n",
    "    data.edge_index,\n",
    "    embedding_dim=50,\n",
    "    walks_per_node=10,\n",
    "    walk_length=20,\n",
    "    context_size=10,\n",
    "    p=1.0,\n",
    "    q=1.0,\n",
    "    num_negative_samples=1,\n",
    ").to(device)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "loader = model.loader(batch_size=128, shuffle=True, num_workers=4)\n",
    "\n",
    "pos_rw, neg_rw = next(iter(loader))\n",
    "\n",
    "model.train()\n",
    "for pos_rw, neg_rw in loader:\n",
    "    optimizer.zero_grad()\n",
    "    loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    # print(loss.item())\n",
    "\n",
    "embeddings = model()\n",
    "embeddings.requires_grad = False\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph_adjacency_matrix, weighted_graph_adjacency_matrix = nx.to_numpy_array(graphs['G_train'], nodelist=sorted(graphs['G_train'].nodes())), nx.to_numpy_array(graphs['G_weighted'], nodelist=sorted(graphs['G_train'].nodes()))\n",
    "edge_index = torch.tensor(np.array(graph_adjacency_matrix.nonzero()), dtype=torch.long)\n",
    "edge_value = weighted_graph_adjacency_matrix[graph_adjacency_matrix.nonzero()]\n",
    "\n",
    "# training for CELL\n",
    "data = Data(x=embeddings, edge_index=edge_index)\n",
    "model = GNN(50)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)\n",
    "model.train()\n",
    "for epoch in range(200):\n",
    "    optimizer.zero_grad()\n",
    "    out = model(data)\n",
    "    src, dst = edge_index\n",
    "    score = (out[src] * out[dst]).sum(dim=-1)\n",
    "    loss = F.mse_loss(score, torch.tensor(edge_value, dtype=torch.float))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        print(f'epoch: {epoch}, loss: {loss.item()}')\n",
    "# edge_index = torch.tensor(np.array(graph.nonzero()), dtype=torch.long)\n",
    "\n",
    "# training for CELL\n",
    "sparse_matrix = sparse.csr_matrix(graph_adjacency_matrix)\n",
    "cell_model = Cell(A=sparse_matrix,\n",
    "             H=10,\n",
    "             callbacks=[EdgeOverlapCriterion(invoke_every=10, edge_overlap_limit=.80)])\n",
    "cell_model.train(steps=400,\n",
    "            optimizer_fn=torch.optim.Adam,\n",
    "            optimizer_args={'lr': 0.1,\n",
    "                            'weight_decay': 1e-7})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils\n",
    "# from utils import lazy_clique_edge_cover\n",
    "from importlib import reload\n",
    "\n",
    "# reconstruct the hypergraph by clique cover\n",
    "# YOU GUY!!!!!!!!!!!!!!!!!!!!!!\n",
    "# BAD API!!!!!!!!!!!!!!!!!!!!!!\n",
    "# G = graphs['G_weighted']\n",
    "# weighted_adjacency_matrix = nx.to_numpy_array(G, nodelist=sorted(G.nodes()))\n",
    "\n",
    "# # sampling cliques\n",
    "# os.remove(f'{config['data_dir']}/{config['dataset']}/cliques_train.pkl')\n",
    "# os.remove(f'{config['data_dir']}/{config['dataset']}/rho.pkl')\n",
    "\n",
    "for i in range(5):\n",
    "    # generate WLIG\n",
    "    generated_graph = cell_model.sample_graph()\n",
    "    graph_prime = generated_graph.A\n",
    "    edge_index_prime = torch.tensor(graph_prime.nonzero(), dtype=torch.long)\n",
    "    x = embeddings\n",
    "    data_prime = Data(x=x, edge_index = edge_index_prime)\n",
    "    out = model(data_prime)\n",
    "    src, dst = edge_index_prime\n",
    "    score = (out[src] * out[dst]).sum(dim=-1)\n",
    "    weight = score.detach().numpy()\n",
    "    weight[weight <= 1] = 1\n",
    "    weight = np.rint(weight).astype(int)\n",
    "    weighted_graph_prime = np.copy(graph_prime)\n",
    "    weighted_graph_prime[weighted_graph_prime.nonzero()] = weight\n",
    "\n",
    "    # sample cliques\n",
    "    cliques = compute_cliques(graphs, config, logger)\n",
    "    sample_cliques_table = cliques['children_cliques_train']\n",
    "    # print(sample_cliques_table)\n",
    "    sample_cliques = []\n",
    "    for v in sample_cliques_table.values():\n",
    "        sample_cliques = sample_cliques + v\n",
    "    sample_cliques = [list(c) for c in sample_cliques]\n",
    "    set_sample_cliques = list(set([tuple(sorted(e)) for e in sample_cliques]))\n",
    "    print(f'len of origin: {len(sample_cliques)}, len of deduplicates: {len(set_sample_cliques)}')\n",
    "\n",
    "    # reconstruct hyperedges\n",
    "    reconstruct_hyperedges = utils.lazy_clique_edge_cover(np.copy(weighted_graph_prime), set_sample_cliques, len(graphs['simplicies_train']))\n",
    "    # random.shuffle(set_sample_hyperedges)\n",
    "    # sample_clique_sizes = [len(c) for c in set_sample_cliques]\n",
    "    # data = np.array(sample_clique_sizes)\n",
    "    # hist, bins = np.histogram(data, bins=np.linspace(0, 5, 6))\n",
    "    # sns.displot(data)\n",
    "    # reconstruct_hyperedges = utils.lazy_clique_edge_cover(weighted_adjacency_matrix, set_sample_cliques, len(graphs['simplicies_train']))\n",
    "    # reconstruct_hyperedges_sizes = [len(e) for e in reconstruct_hyperedges]\n",
    "    # data = np.array(reconstruct_hyperedges_sizes)\n",
    "    # sns.displot(data)\n",
    "    set_reconstruct_hyperedges = set([tuple(sorted(e)) for e in reconstruct_hyperedges])\n",
    "    print(f'len: {len(graphs['simplicies_train'])}, {graphs['simplicies_train']}')\n",
    "    print(f'len: {len(set_reconstruct_hyperedges)}, {set_reconstruct_hyperedges}')\n",
    "    save_hypergraph(set_reconstruct_hyperedges, f'./baseline/HyperPLR/{config['dataset']}/reconstruct_hyperedges_{i}.txt')\n",
    "\n",
    "\n",
    "print('original hypergraph', hypergraph_metrics(graphs['simplicies_train']))\n",
    "print('reconstructed hypergraph', hypergraph_metrics(set_reconstruct_hyperedges))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import networkx as nx\n",
    "import community\n",
    "import itertools\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "\n",
    "\n",
    "def hypergraph_metrics(hg):\n",
    "    # original hypergraph\n",
    "    num_edges = len(hg)\n",
    "    nodes = set()\n",
    "    node_degrees = {}\n",
    "    for edge in hg:\n",
    "        for node in edge:\n",
    "            nodes.add(node)\n",
    "            node_degrees[node] = node_degrees.get(node, 0) + 1\n",
    "    num_nodes = len(nodes)\n",
    "    \n",
    "    # density\n",
    "    density = num_edges / num_nodes\n",
    "\n",
    "    # Average size\n",
    "    avg_size = sum(len(edge) for edge in hg) / num_edges\n",
    "\n",
    "    # Average degree\n",
    "    avg_degree = sum(node_degrees.values()) / num_nodes\n",
    "\n",
    "\n",
    "    # projected graph\n",
    "    G = nx.Graph()\n",
    "    # Add all nodes from the hypergraph\n",
    "    nodes = set(node for edge in hg for node in edge)\n",
    "    G.add_nodes_from(nodes)\n",
    "    # For each hyperedge, create a clique\n",
    "    for edge in hg:\n",
    "        # Add edges between all pairs of nodes in the hyperedge\n",
    "        G.add_edges_from(itertools.combinations(edge, 2))\n",
    "    \n",
    "    part_G = community.best_partition(G)\n",
    "    mod_G = community.modularity(part_G, G)\n",
    "\n",
    "\n",
    "    # bipartite graph\n",
    "    B = nx.Graph()\n",
    "    # Add nodes for the original vertices (left set)\n",
    "    left_nodes = set(node for edge in hg for node in edge)\n",
    "    B.add_nodes_from(left_nodes, bipartite=0)\n",
    "    # Add nodes for the hyperedges (right set)\n",
    "    right_nodes = [f'e{i}' for i in range(len(hg))]\n",
    "    B.add_nodes_from(right_nodes, bipartite=1)\n",
    "    # Add edges between vertices and their corresponding hyperedges\n",
    "    for i, edge in enumerate(hg):\n",
    "        for node in edge:\n",
    "            B.add_edge(node, f'e{i}')\n",
    "\n",
    "\n",
    "    part_B = community.best_partition(B)\n",
    "    mod_B = community.modularity(part_B, B)\n",
    "\n",
    "    return {\n",
    "        \"density\": density,\n",
    "        \"average_size\": avg_size,\n",
    "        \"average_degree\": avg_degree,\n",
    "        \"coefficient\": nx.average_clustering(G),\n",
    "        \"G_modularity\": mod_G,\n",
    "        \"B_modularity\": mod_B\n",
    "    }\n",
    "\n",
    "def load_hypergraph(path, model):\n",
    "    with open(path, 'r') as f:\n",
    "        hg = f.readlines()\n",
    "    if model == 'HyperDK00' or model == 'HyperDK11' or model == 'HyperPLR':\n",
    "        hg = [list(map(int, e.split())) for e in hg]\n",
    "    else:\n",
    "        hg = [list(map(int, e.split(','))) for e in hg]\n",
    "    return hg\n",
    "\n",
    "metric_baseline = defaultdict(list)\n",
    "\n",
    "\n",
    "def get_metrics_baseline(graph_path):\n",
    "    models = os.listdir(graph_path)\n",
    "    for model in models:\n",
    "        graphs = os.listdir(f'{graph_path}/{model}')\n",
    "        for graph in graphs:\n",
    "            hypergraphs = os.listdir(f'{graph_path}/{model}/{graph}')\n",
    "            for hypergraph in hypergraphs:\n",
    "                hg = load_hypergraph(f'{graph_path}/{model}/{graph}/{hypergraph}', model)\n",
    "                metric = hypergraph_metrics(hg)\n",
    "                print(metric)\n",
    "                metric_baseline[(graph, model)].append(metric)\n",
    "\n",
    "    return metric_baseline\n",
    "\n",
    "        # for hypergraphs in gen_model:\n",
    "        #     for hg_file in hypergraphs:\n",
    "        #         hg = load_hypergraph(hg_file)\n",
    "        #         metric = hypergraph_metrics(hg)\n",
    "        #         print(metric)\n",
    "\n",
    "metric_baseline = get_metrics_baseline('./generate_graphs')\n",
    "metric_baseline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle.dump(metric_baseline, open('./metric_baseline.pkl', 'wb'))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hygen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
