{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcadbc8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "########################################\n",
    "# Single Cell Code for Multi-Seed Training\n",
    "########################################\n",
    "import sys\n",
    "import os \n",
    "sys.path.append('../../')\n",
    "sys.path.append('../')\n",
    "from time import time\n",
    "import logging\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch_geometric.utils import degree\n",
    "from torch.autograd import Variable\n",
    "import random\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "\n",
    "# Import your utilities and models\n",
    "from gmixupUtils import stat_graph, split_class_graphs, align_graphs\n",
    "from gmixupUtils import two_graphons_mixup, universal_svd, get_graphon\n",
    "from gmixupUtils import GIN\n",
    "from Moment.tools import motifs_to_induced_motifs, lipmlp, train_momentnet\n",
    "from SIGL.tools import *\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import dense_to_sparse\n",
    "import subprocess as sp\n",
    "from scipy.special import comb\n",
    "seed = 19\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "print(f\"INFO: Set CUBLAS_WORKSPACE_CONFIG to {os.environ.get('CUBLAS_WORKSPACE_CONFIG')}\")\n",
    "torch.use_deterministic_algorithms(True)\n",
    "\n",
    "ORCA_DIR = '../../orca/'\n",
    "\n",
    "def edge_list_reindexed(G):\n",
    "    idx = 0\n",
    "    id2idx = dict()\n",
    "    for u in G.nodes():\n",
    "        id2idx[str(u)] = idx\n",
    "        idx += 1\n",
    "\n",
    "    edges = []\n",
    "    for (u, v) in G.edges():\n",
    "        edges.append((id2idx[str(u)], id2idx[str(v)]))\n",
    "    return edges\n",
    "\n",
    "\n",
    "def orca(graph):\n",
    "    tmp_file_path = os.path.join(ORCA_DIR, f'tmptmp-{random.random():.4f}.txt')\n",
    "    f = open(tmp_file_path, 'w+')\n",
    "    f.write(str(graph.number_of_nodes()) + ' ' + str(graph.number_of_edges()) + '\\n')\n",
    "    for (u, v) in edge_list_reindexed(graph):\n",
    "        f.write(str(u) + ' ' + str(v) + '\\n')\n",
    "    f.close()\n",
    "\n",
    "\n",
    "    output = sp.check_output([\"../../orca/orca\",'4', tmp_file_path, 'outputnew.txt'])\n",
    "    with open('outputnew.txt', 'r') as file:\n",
    "        output = file.read()\n",
    "    output = output.strip()\n",
    "    node_orbit_counts = np.array([list(map(int, node_cnts.strip().split(' ')))\n",
    "                                  for node_cnts in output.strip('\\n').split('\\n')])\n",
    "    try:\n",
    "        os.remove(tmp_file_path)\n",
    "    except OSError:\n",
    "        pass\n",
    "\n",
    "    return node_orbit_counts\n",
    "\n",
    "\n",
    "def count2density(node_orbit_counts, graph_size):\n",
    "\n",
    "\n",
    "    all_possible_motifs = {}\n",
    "    for size in [2, 3, 4]:\n",
    "        all_possible_motifs[size] = comb(graph_size, size, exact=True)\n",
    "\n",
    "    map_loc2motif = [1, 2, 1, 2, 2, 1, 3, 2, 1]\n",
    "    node_size = np.array(1*[2] + 2*[3] + 6*[4])\n",
    "    rewiring_normalizer = [ 1.,  3.,  1., 12.,  4.,  3., 12.,  6.,  1.]\n",
    "    non_unique_count = np.zeros(9)\n",
    "    density = np.zeros(9)\n",
    "\n",
    "    # summing over nodes\n",
    "    count_over_nodes = np.sum(node_orbit_counts, axis=0)\n",
    "    non_unique_count[0] = count_over_nodes[0]\n",
    "    for i in range(1, 9):\n",
    "      start_idx = sum(map_loc2motif[:i])\n",
    "      non_unique_count[i] = sum(count_over_nodes[start_idx: start_idx+map_loc2motif[i]])\n",
    "\n",
    "    unique_count = non_unique_count / node_size\n",
    "\n",
    "    for i in range(9):\n",
    "      density[i] = unique_count[i] / (rewiring_normalizer[i] * all_possible_motifs[node_size[i]])\n",
    "    return density\n",
    "\n",
    "\n",
    "\n",
    "def moment_graphon(resolution, net):\n",
    "    # Generate data for the plot\n",
    "    x = np.linspace(0, 1, resolution)\n",
    "    y = np.linspace(0, 1, resolution)\n",
    "    X, Y = np.meshgrid(x, y)\n",
    "\n",
    "    # Network output\n",
    "    inputs = torch.tensor(np.stack((X.flatten(), Y.flatten()), axis=1), dtype=torch.float32).to(device)\n",
    "    Z_net = net(inputs).cpu().detach().numpy().reshape(X.shape)\n",
    "    Z_sym = np.copy(Z_net)\n",
    "\n",
    "    # Copy lower triangle to upper triangle (excluding the diagonal)\n",
    "    #i_lower = np.tril_indices(Z_net.shape[0], -1)\n",
    "    #Z_sym[i_lower] = Z_sym.T[i_lower]\n",
    "\n",
    "    # Copy upper triangle to lower triangle (excluding the diagonal)\n",
    "    i_upper = np.triu_indices(Z_net.shape[0], 1)\n",
    "    Z_sym[i_upper] = Z_sym.T[i_upper]\n",
    "\n",
    "\n",
    "    np.fill_diagonal(Z_sym, 0)\n",
    "    #np.fill_diagonal(Z_sym, np.diag(Z_real))\n",
    "\n",
    "    # replace 1 with 0.8 in Z_sym\n",
    "    \n",
    "    #print(\"GW Distance  = \" + str(gw_distance(Z_sym, Z_real)))\n",
    "    return Z_sym\n",
    "\n",
    "def train_Moment(moments, weights, model_name):\n",
    "    \"\"\"\n",
    "    Train the Moment model using the dataset and return GW loss, centrality NMSE averages, and standard deviations.\n",
    "\n",
    "    Args:\n",
    "        dataset_name (str): Name of the dataset file in the dataset folder.\n",
    "        graphon_idx (int): Index of the graphon.\n",
    "\n",
    "    Returns:\n",
    "        tuple: Avg GW loss, Std GW loss, Avg and Std NMSE for each centrality measure.\n",
    "    \"\"\"\n",
    "\n",
    "    # Default parameters\n",
    "    epochs = 2000\n",
    "    #patience = 600\n",
    "    #lr = 1e-3\n",
    "    #N = 30000\n",
    "    #hid_dim = 64\n",
    "    num_motifs = 9\n",
    "    weight_mode = 0\n",
    "    #lr = 9.26e-05\n",
    "    lr = 1e-4\n",
    "    N = 20000\n",
    "    hid_dim = 5*[96]\n",
    "    num_layers = 5\n",
    "    \n",
    "    patience = 600\n",
    "    w0 = 9.\n",
    "\n",
    "  \n",
    "\n",
    "    Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],\n",
    "                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],\n",
    "      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,\n",
    "       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]\n",
    "       ]]\n",
    "\n",
    "    induced_list = motifs_to_induced_motifs(Es)\n",
    "    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    real_moments = torch.tensor(moments).to(device)\n",
    "    # print density counting is finished\n",
    "    print(\"Density counting is finished\")\n",
    "    print(real_moments)\n",
    "    while True:\n",
    "        \n",
    "        if model_name == 'SIREN':\n",
    "            model = SirenNet(2, hid_dim, 1, num_layers=num_layers, w0=w0, w0_initial=30.).train().to(device)\n",
    "        elif model_name == 'MLP':\n",
    "            lr = 1e-3\n",
    "            hid_dim = 64\n",
    "            model = lipmlp([2, hid_dim, hid_dim, 1]).train().to(device)\n",
    "        losses = train_momentnet(model, induced_list[:num_motifs], real_moments, 4, N, epochs, patience, lr, device, weights)\n",
    "\n",
    "        try:\n",
    "            if (losses[0] - losses[-1])/losses[0] > 1e-3:\n",
    "                break\n",
    "            else:\n",
    "                print((losses[0] - losses[-1])/losses[0])\n",
    "        except:\n",
    "            \n",
    "            continue\n",
    "\n",
    "\n",
    "\n",
    "    return model\n",
    "\n",
    "\n",
    "# Set up logging\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.DEBUG)\n",
    "formatter = logging.Formatter('%(asctime)s - %(levelname)s: - %(message)s', datefmt='%Y-%m-%d')\n",
    "if not logger.handlers:\n",
    "    # Optionally, log to the screen\n",
    "    ch = logging.StreamHandler()\n",
    "    ch.setLevel(logging.DEBUG)\n",
    "    ch.setFormatter(formatter)\n",
    "    logger.addHandler(ch)\n",
    "logging.getLogger('matplotlib').setLevel(logging.WARNING)\n",
    "\n",
    "\n",
    "########################################\n",
    "# Helper Function Definitions\n",
    "########################################\n",
    "\n",
    "def prepare_dataset_x(dataset):\n",
    "    logger.info(\"Prepare dataset x\")\n",
    "    if dataset[0].x is None:\n",
    "        logger.info(\"dataset[0].x is None\")\n",
    "        max_degree = 0\n",
    "        degs_all = []\n",
    "        for data in dataset:\n",
    "            d = degree(data.edge_index[0], dtype=torch.long)\n",
    "            degs_all.append(d)\n",
    "            max_degree = max(max_degree, d.max().item())\n",
    "            data.num_nodes = int(torch.max(data.edge_index)) + 1\n",
    "\n",
    "        if max_degree < 2000:\n",
    "            for data in dataset:\n",
    "                degs = degree(data.edge_index[0], dtype=torch.long)\n",
    "                data.x = F.one_hot(degs, num_classes=max_degree+1).to(torch.float)\n",
    "        else:\n",
    "            deg = torch.cat(degs_all, dim=0).to(torch.float)\n",
    "            mean, std = deg.mean().item(), deg.std().item()\n",
    "            for data in dataset:\n",
    "                degs = degree(data.edge_index[0], dtype=torch.long)\n",
    "                data.x = ((degs - mean) / std).view(-1, 1)\n",
    "    return dataset\n",
    "\n",
    "\n",
    "def prepare_dataset_onehot_y(dataset):\n",
    "    y_set = set()\n",
    "    for data in dataset:\n",
    "        y_set.add(int(data.y))\n",
    "    num_classes = len(y_set)\n",
    "    for data in dataset:\n",
    "        # Each data.y becomes a one-hot vector; here we take the first element since \n",
    "        # F.one_hot returns a vector wrapped in an extra dimension.\n",
    "        data.y = F.one_hot(data.y, num_classes=num_classes).to(torch.float)[0]\n",
    "    return dataset\n",
    "\n",
    "\n",
    "def mixup_cross_entropy_loss(input, target, size_average=True):\n",
    "    \"\"\"Mixup version of cross entropy loss.\"\"\"\n",
    "    assert input.size() == target.size()\n",
    "    assert isinstance(input, Variable) and isinstance(target, Variable)\n",
    "    loss = - torch.sum(input * target)\n",
    "    return loss / input.size()[0] if size_average else loss\n",
    "\n",
    "\n",
    "def train(model, train_loader, optimizer, device, num_classes):\n",
    "    model.train()\n",
    "    loss_all = 0\n",
    "    graph_all = 0\n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data.x, data.edge_index, data.batch)\n",
    "        y = data.y.view(-1, num_classes)\n",
    "        loss = mixup_cross_entropy_loss(output, y)\n",
    "        loss.backward()\n",
    "        loss_all += loss.item() * data.num_graphs\n",
    "        #print(f\"loss : {loss_all}\")\n",
    "        graph_all += data.num_graphs\n",
    "        optimizer.step()\n",
    "    #print(f\"Output : {output}\")\n",
    "    return loss_all / graph_all\n",
    "\n",
    "\n",
    "def test(model, loader, device, num_classes):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    loss_all = 0\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        output = model(data.x, data.edge_index, data.batch)\n",
    "        pred = output.max(dim=1)[1]\n",
    "        y = data.y.view(-1, num_classes)\n",
    "        loss_all += mixup_cross_entropy_loss(output, y).item() * data.num_graphs\n",
    "        \n",
    "        # Convert one-hot to class indices for comparison\n",
    "        y_labels = y.max(dim=1)[1]\n",
    "        correct += pred.eq(y_labels).sum().item()\n",
    "        total += data.num_graphs\n",
    "    return correct / total, loss_all / total\n",
    "\n",
    "\n",
    "\n",
    "def two_moments_mixup(two_moments, la=0.5, num_sample=20, ge='ISGL', resolution=None):\n",
    "\n",
    "    label = la * two_moments[0][0] + (1 - la) * two_moments[1][0]\n",
    "    sample_graph_label = torch.from_numpy(label).type(torch.float32)\n",
    "\n",
    "    \n",
    "    # Res = [int(resolution[0]) for _ in range(num_sample)]\n",
    "    #print(Res)\n",
    "\n",
    "    #print(two_moments[0][1])\n",
    "    #print(two_moments[1][1])\n",
    "    new_moment = la * two_moments[0][1] + (1 - la) * two_moments[1][1]\n",
    "    #sample_graph = (np.random.rand(*new_graphon.shape) <= new_graphon).astype(np.int32)\n",
    "\n",
    "    trained_model = train_Moment(new_moment, 0, \"MLP\")\n",
    "    graphon = moment_graphon(resolution, trained_model)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "    sample_graphs = []\n",
    "    for i in range(num_sample):\n",
    "        \n",
    "        \n",
    "        done = 0\n",
    "        while True:\n",
    "            if done > 20:\n",
    "                return []\n",
    "            num_nodes = random.randint(300, 600)\n",
    "            sample_graph = simulate_graphs(graphon, seed_gsize =123, seed_edge =123, num_graphs = 1, num_nodes = num_nodes, graph_size = 'fixed', offset =0)[0]\n",
    "\n",
    "            # print number of nodes\n",
    "            #print(f\"Number of nodes: {num_nodes}\")\n",
    "\n",
    "            #sample_graph = np.triu(sample_graph)\n",
    "            #sample_graph = sample_graph + sample_graph.T - np.diag(np.diag(sample_graph))\n",
    "            #sample_graph = sample_graph[sample_graph.sum(axis=1) != 0]\n",
    "            #sample_graph = sample_graph[:, sample_graph.sum(axis=0) != 0]\n",
    "\n",
    "            A = torch.from_numpy(sample_graph)\n",
    "            edge_index, _ = dense_to_sparse(A)\n",
    "            num_nodes = sample_graph.shape[0]\n",
    "\n",
    "            if num_nodes == 0:\n",
    "                print('num_nodes is 0')\n",
    "                done += 1\n",
    "                continue\n",
    "\n",
    "            # continue if the graph is empty in terms of edges\n",
    "            if edge_index.shape[1] == 0:\n",
    "                print('edge_index is 0')\n",
    "                done += 1\n",
    "                continue\n",
    "\n",
    "            break\n",
    "\n",
    "\n",
    "        pyg_graph = Data()\n",
    "        pyg_graph.y = sample_graph_label\n",
    "        pyg_graph.edge_index = edge_index\n",
    "        pyg_graph.num_nodes = num_nodes\n",
    "        sample_graphs.append(pyg_graph)\n",
    "        \n",
    "    return sample_graphs\n",
    "\n",
    "########################################\n",
    "# Default Parameters and Seed List Setup\n",
    "########################################\n",
    "\n",
    "# Default parameters\n",
    "data_path      = \"./\"\n",
    "dataset_name   = \"REDDIT-BINARY\"\n",
    "model_name     = \"GIN\"\n",
    "num_epochs     = 400\n",
    "batch_size     = 128\n",
    "learning_rate  = 0.01\n",
    "num_hidden     = 64\n",
    "lam_range      = [0.1, 0.2]\n",
    "aug_ratio      = 0.2\n",
    "aug_num        = 10\n",
    "ge             = \"Moment\"   # options: \"ISGL\", \"Moment\", \"IGNR\", etc.\n",
    "log_screen     = True\n",
    "gmixup         = True\n",
    "n_epochs_inr   = 20\n",
    "\n",
    "Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],\n",
    "                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],\n",
    "      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,\n",
    "       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]\n",
    "       ]]\n",
    "\n",
    "induced_list = motifs_to_induced_motifs(Es)\n",
    "\n",
    "# Example seed list\n",
    "seed_list = [1314, 21314, 31314, 11314, 41314, 51314, 61314, 71314]\n",
    "\n",
    "# To store metrics across seeds\n",
    "all_best_test_acc = []\n",
    "all_last10_avg_acc = []\n",
    "\n",
    "########################################\n",
    "# Main Loop: Iterate over Seed List\n",
    "########################################\n",
    "\n",
    "for seed in seed_list:\n",
    "    logger.info(\"==========================================\")\n",
    "    logger.info(f\"Starting training for seed {seed}\")\n",
    "\n",
    "    # Set seeds for reproducibility\n",
    "    #torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
    "    logger.info(f\"Running on device: {device}\")\n",
    "\n",
    "    # Load dataset\n",
    "    path = osp.join(data_path, dataset_name)\n",
    "    dataset = TUDataset(path, name=dataset_name)\n",
    "    dataset = list(dataset)\n",
    "    \n",
    "    # Reshape labels\n",
    "    for graph in dataset:\n",
    "        graph.y = graph.y.view(-1)\n",
    "    dataset = prepare_dataset_onehot_y(dataset)\n",
    "\n",
    "    # Log global dataset statistics\n",
    "    avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset)\n",
    "    logger.info(f\"Dataset {dataset_name}: {len(dataset)} graphs, avg nodes: {avg_num_nodes}\")\n",
    "\n",
    "    # Shuffle and split dataset: 70% train, 10% validation (from train), 20% test\n",
    "    random.shuffle(dataset)\n",
    "    train_nums = int(len(dataset) * 0.7)\n",
    "    train_val_nums = int(len(dataset) * 0.8)\n",
    "    \n",
    "    # Log training subset statistics\n",
    "    avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset[:train_nums])\n",
    "    logger.info(f\"Training graphs: avg nodes: {avg_num_nodes}, max nodes: {max_num_nodes}\")\n",
    "\n",
    "    resolution = 1000 #int(median_num_nodes)\n",
    "    logger.info(f\"Resolution set to: {resolution}\")\n",
    "    \n",
    "    # Augment dataset using graphon mixup if enabled\n",
    "    if gmixup:\n",
    "        class_graphs = split_class_graphs(dataset[:train_nums])\n",
    "\n",
    "        ###########################################\n",
    "        def compute_moment_scores(moment_matrix):\n",
    "            \"\"\"\n",
    "            Computes discrimination scores for each moment (column) in a matrix.\n",
    "            \n",
    "            Parameters:\n",
    "                moment_matrix (np.ndarray): A 2D NumPy array of shape (n_classes, m_moments)\n",
    "                                            where each row corresponds to a class.\n",
    "            \n",
    "            Returns:\n",
    "                std_scores (np.ndarray): Standard deviation scores for each moment.\n",
    "                range_scores (np.ndarray): Range (max-min) scores for each moment.\n",
    "                avg_pairwise (np.ndarray): Average pairwise difference scores for each moment.\n",
    "            \"\"\"\n",
    "            # Method 1: Standard deviation across classes for each moment.\n",
    "            std_scores = np.std(moment_matrix, axis=0)\n",
    "            \n",
    "            # Method 2: Range (maximum minus minimum) for each moment.\n",
    "            range_scores = np.ptp(moment_matrix, axis=0)  # np.ptp returns the range (peak to peak)\n",
    "            \n",
    "            # Method 3: Average pairwise absolute differences between classes for each moment.\n",
    "            n, m = moment_matrix.shape\n",
    "            avg_pairwise = np.zeros(m)\n",
    "            for j in range(m):\n",
    "                pairwise_diffs = []\n",
    "                for i in range(n):\n",
    "                    for k in range(i + 1, n):\n",
    "                        diff = abs(moment_matrix[i, j] - moment_matrix[k, j])\n",
    "                        pairwise_diffs.append(diff)\n",
    "                avg_pairwise[j] = np.mean(pairwise_diffs)\n",
    "                \n",
    "            return std_scores, range_scores, avg_pairwise\n",
    "\n",
    "\n",
    "\n",
    "        Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],\n",
    "                                                                        [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],\n",
    "            [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],\n",
    "            [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,\n",
    "            [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],\n",
    "            [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],\n",
    "            [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]\n",
    "            ]]\n",
    "\n",
    "        induced_list = motifs_to_induced_motifs(Es)\n",
    "\n",
    "        # Suppose we have data from 3 classes and 4 moments per class:\n",
    "        moment_matrix = []\n",
    "\n",
    "        for label, graphs in class_graphs:\n",
    "            estimated_densities = np.zeros(9)\n",
    "            nx_graphs = [nx.from_numpy_array(graph) for graph in graphs]\n",
    "            for graph in nx_graphs:\n",
    "                # only if number of nodes is greater than 6\n",
    "                if graph.number_of_nodes() < 6:\n",
    "                    continue\n",
    "\n",
    "                #if graph.number_of_nodes() < resolution:\n",
    "                #    continue\n",
    "                node_orbit_counts = orca(graph)\n",
    "                density = count2density(node_orbit_counts, graph.number_of_nodes())\n",
    "                estimated_densities += density\n",
    "            #print(f\"Estimated densities for label {label}: {estimated_densities}\")\n",
    "            real_moments = estimated_densities / len(graphs)\n",
    "            moment_matrix.append([label, real_moments])\n",
    "        \n",
    "        seed_torch = 0\n",
    "        torch.manual_seed(seed_torch)\n",
    "        torch.cuda.manual_seed(seed_torch)\n",
    "        torch.cuda.manual_seed_all(seed_torch)\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "        os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "        torch.use_deterministic_algorithms(True)   \n",
    "        num_sample = int(train_nums * aug_ratio / aug_num)\n",
    "        lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(aug_num,))\n",
    "        new_graph = []\n",
    "        #lam_list = [0.18]\n",
    "        for lam in lam_list:\n",
    "            logger.info(f\"lam: {lam}\")\n",
    "            logger.info(f\"num_sample: {num_sample}\")\n",
    "            two_moments = random.sample(moment_matrix, 2)\n",
    "            \n",
    "            new_graph += two_moments_mixup(two_moments, la=lam, num_sample=num_sample, ge=\"USVT\", resolution=resolution)\n",
    "            #logger.info(f\"New graph label: {new_graph[-1].y}\")\n",
    "            \n",
    "        avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(new_graph)\n",
    "        logger.info(f\"New graphs: avg nodes: {avg_num_nodes}, min nodes: {min_num_nodes}, max nodes: {max_num_nodes}\")\n",
    "        logger.info(f\"Real augmentation ratio: {len(new_graph)/train_nums}\")\n",
    "        dataset = new_graph + dataset\n",
    "        train_nums += len(new_graph)\n",
    "        train_val_nums += len(new_graph)\n",
    "\n",
    "    # Prepare node features for whole dataset\n",
    "    dataset = prepare_dataset_x(dataset)\n",
    "    logger.info(f\"Dataset feature shape: {dataset[0].x.shape}\")\n",
    "    logger.info(f\"Dataset label shape: {dataset[0].y.shape}\")\n",
    "\n",
    "    num_features = dataset[0].x.shape[1]\n",
    "    num_classes = dataset[0].y.shape[0]\n",
    "\n",
    "    for data in dataset:\n",
    "        if not hasattr(data, 'edge_attr') or data.edge_attr is None:\n",
    "            data.edge_attr = torch.ones((data.edge_index.size(1), 3)) \n",
    "\n",
    "    # Split dataset into train, validation and test\n",
    "    train_dataset = dataset[:train_nums]\n",
    "    random.shuffle(train_dataset)\n",
    "    val_dataset = dataset[train_nums:train_val_nums]\n",
    "    test_dataset = dataset[train_val_nums:]\n",
    "    \n",
    "    logger.info(f\"Train dataset size: {len(train_dataset)}\")\n",
    "    logger.info(f\"Validation dataset size: {len(val_dataset)}\")\n",
    "    logger.info(f\"Test dataset size: {len(test_dataset)}\")\n",
    "\n",
    "    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "    val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=batch_size)\n",
    "\n",
    "    # Instantiate the model\n",
    "    if model_name == \"GIN\":\n",
    "        model_instance = GIN(num_features=num_features, num_classes=num_classes, num_hidden=num_hidden).to(device)\n",
    "    else:\n",
    "        logger.info(\"No valid model specified.\")\n",
    "        continue\n",
    "\n",
    "    optimizer = torch.optim.Adam(model_instance.parameters(), lr=learning_rate, weight_decay=5e-4)\n",
    "    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)\n",
    "\n",
    "\n",
    "\n",
    "    # Train the model\n",
    "    last_10_epoch_acc = []  # To record test acc for the last 10 epochs\n",
    "    best_val_acc = 0\n",
    "    best_test_acc = 0\n",
    "    #for name, param in model_instance.named_parameters():\n",
    "    #    if param.requires_grad:\n",
    "    #        logger.info(f\"  Layer: {name}, Shape: {param.shape}, Sum: {torch.sum(param.data).item():.6f}, Mean: {torch.mean(param.data).item():.6f}\")\n",
    "    for epoch in range(1, num_epochs):\n",
    "    \n",
    "        train_loss = train(model_instance, train_loader, optimizer, device, num_classes)\n",
    "\n",
    "        #for name, param in model_instance.named_parameters():\n",
    "        #    if param.requires_grad:\n",
    "        #        logger.info(f\"  Layer: {name}, Shape: {param.shape}, Sum: {torch.sum(param.data).item():.6f}, Mean: {torch.mean(param.data).item():.6f}\")\n",
    "        #break\n",
    "        val_acc, val_loss = test(model_instance, val_loader, device, num_classes)\n",
    "        test_acc, test_loss = test(model_instance, test_loader, device, num_classes)\n",
    "        scheduler.step()\n",
    "\n",
    "        if val_acc >= best_val_acc and epoch > 125:\n",
    "            best_val_acc = val_acc\n",
    "            best_test_acc = test_acc\n",
    "            best_epoch = epoch\n",
    "\n",
    "        if epoch > num_epochs - 10:\n",
    "            last_10_epoch_acc.append(test_acc)\n",
    "\n",
    "        if epoch % 30 == 0:\n",
    "            print('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f}, Val Acc: {:.6f}, Test Acc: {:.6f}'.format(\n",
    "                epoch, train_loss, val_loss, test_loss, val_acc, test_acc))\n",
    "    \n",
    "    avg_last10 = np.mean(last_10_epoch_acc)\n",
    "    print(f\"Seed {seed}: Last 10 epochs average test acc: {np.round(avg_last10,3)}\")\n",
    "    print(f\"Seed {seed}: Best test acc (based on validation) = {best_test_acc} at epoch {best_epoch}\")\n",
    "\n",
    "    all_best_test_acc.append(best_test_acc)\n",
    "    all_last10_avg_acc.append(avg_last10)\n",
    "\n",
    "########################################\n",
    "# After All Seeds: Print Final Metrics\n",
    "########################################\n",
    "print(\"==========================================\")\n",
    "print(\"Overall Results Across Seeds:\")\n",
    "print(f\"Best Test Accuracy: Mean = {np.mean(all_best_test_acc):.3f}, Std = {np.std(all_best_test_acc):.3f}\")\n",
    "print(f\"Last 10 Epoch Test Accuracy Averag e: Mean = {np.mean(all_last10_avg_acc):.3f}, Std = {np.std(all_last10_avg_acc):.3f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f9728c3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
