{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcadbc8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "########################################\n",
    "# Single Cell Code for Multi-Seed Training\n",
    "########################################\n",
    "import sys\n",
    "import os \n",
    "sys.path.append('../../')\n",
    "sys.path.append('../')\n",
    "from time import time\n",
    "import logging\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch_geometric.utils import degree\n",
    "from torch.autograd import Variable\n",
    "import random\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "import schedulefree\n",
    "\n",
    "# Import your utilities and models\n",
    "from gmixupUtils import stat_graph, split_class_graphs, align_graphs\n",
    "from gmixupUtils import two_graphons_mixup, universal_svd, get_graphon\n",
    "from gmixupUtils import GIN\n",
    "from Moment.tools import motifs_to_induced_motifs, lipmlp, train_momentnet\n",
    "from SIGL.tools import *\n",
    "import networkx as nx\n",
    "from torch_geometric.utils import dense_to_sparse\n",
    "import subprocess as sp\n",
    "from scipy.special import comb\n",
    "seed = 21\n",
    "random.seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "print(f\"INFO: Set CUBLAS_WORKSPACE_CONFIG to {os.environ.get('CUBLAS_WORKSPACE_CONFIG')}\")\n",
    "torch.use_deterministic_algorithms(True)\n",
    "\n",
    "ORCA_DIR = '../../orca/'\n",
    "\n",
    "def edge_list_reindexed(G):\n",
    "    idx = 0\n",
    "    id2idx = dict()\n",
    "    for u in G.nodes():\n",
    "        id2idx[str(u)] = idx\n",
    "        idx += 1\n",
    "\n",
    "    edges = []\n",
    "    for (u, v) in G.edges():\n",
    "        edges.append((id2idx[str(u)], id2idx[str(v)]))\n",
    "    return edges\n",
    "\n",
    "\n",
    "def orca(graph):\n",
    "    tmp_file_path = os.path.join(ORCA_DIR, f'tmptmp55-{random.random():.4f}.txt')\n",
    "    f = open(tmp_file_path, 'w+')\n",
    "    f.write(str(graph.number_of_nodes()) + ' ' + str(graph.number_of_edges()) + '\\n')\n",
    "    for (u, v) in edge_list_reindexed(graph):\n",
    "        f.write(str(u) + ' ' + str(v) + '\\n')\n",
    "    f.close()\n",
    "\n",
    "\n",
    "    output = sp.check_output([\"../../orca/orca\",'4', tmp_file_path, 'outputnew.txt'])\n",
    "    with open('outputnew.txt', 'r') as file:\n",
    "        output = file.read()\n",
    "    output = output.strip()\n",
    "    node_orbit_counts = np.array([list(map(int, node_cnts.strip().split(' ')))\n",
    "                                  for node_cnts in output.strip('\\n').split('\\n')])\n",
    "    try:\n",
    "        os.remove(tmp_file_path)\n",
    "    except OSError:\n",
    "        pass\n",
    "\n",
    "    return node_orbit_counts\n",
    "\n",
    "\n",
    "def count2density(node_orbit_counts, graph_size):\n",
    "\n",
    "\n",
    "    all_possible_motifs = {}\n",
    "    for size in [2, 3, 4]:\n",
    "        all_possible_motifs[size] = comb(graph_size, size, exact=True)\n",
    "\n",
    "    map_loc2motif = [1, 2, 1, 2, 2, 1, 3, 2, 1]\n",
    "    node_size = np.array(1*[2] + 2*[3] + 6*[4])\n",
    "    rewiring_normalizer = [ 1.,  3.,  1., 12.,  4.,  3., 12.,  6.,  1.]\n",
    "    non_unique_count = np.zeros(9)\n",
    "    density = np.zeros(9)\n",
    "\n",
    "    # summing over nodes\n",
    "    count_over_nodes = np.sum(node_orbit_counts, axis=0)\n",
    "    non_unique_count[0] = count_over_nodes[0]\n",
    "    for i in range(1, 9):\n",
    "      start_idx = sum(map_loc2motif[:i])\n",
    "      non_unique_count[i] = sum(count_over_nodes[start_idx: start_idx+map_loc2motif[i]])\n",
    "\n",
    "    unique_count = non_unique_count / node_size\n",
    "\n",
    "    for i in range(9):\n",
    "      density[i] = unique_count[i] / (rewiring_normalizer[i] * all_possible_motifs[node_size[i]])\n",
    "    return density\n",
    "\n",
    "\n",
    "\n",
    "def moment_graphon(resolution, net):\n",
    "    # Generate data for the plot\n",
    "    x = np.linspace(0, 1, resolution)\n",
    "    y = np.linspace(0, 1, resolution)\n",
    "    X, Y = np.meshgrid(x, y)\n",
    "\n",
    "    # Network output\n",
    "    inputs = torch.tensor(np.stack((X.flatten(), Y.flatten()), axis=1), dtype=torch.float32).to(device)\n",
    "    Z_net = net(inputs).cpu().detach().numpy().reshape(X.shape)\n",
    "    Z_sym = np.copy(Z_net)\n",
    "\n",
    "    # Copy lower triangle to upper triangle (excluding the diagonal)\n",
    "    #i_lower = np.tril_indices(Z_net.shape[0], -1)\n",
    "    #Z_sym[i_lower] = Z_sym.T[i_lower]\n",
    "\n",
    "    # Copy upper triangle to lower triangle (excluding the diagonal)\n",
    "    i_upper = np.triu_indices(Z_net.shape[0], 1)\n",
    "    Z_sym[i_upper] = Z_sym.T[i_upper]\n",
    "\n",
    "\n",
    "    np.fill_diagonal(Z_sym, 0)\n",
    "    #np.fill_diagonal(Z_sym, np.diag(Z_real))\n",
    "\n",
    "    # replace 1 with 0.8 in Z_sym\n",
    "    \n",
    "    #print(\"GW Distance  = \" + str(gw_distance(Z_sym, Z_real)))\n",
    "    return Z_sym\n",
    "\n",
    "def visualize_nn_graphon(model, resolution=100, device=None, title=\"Neural Network Graphon Visualization\"):\n",
    "    \"\"\"\n",
    "    Visualizes a graphon approximated by a neural network.\n",
    "\n",
    "    The graphon W(x,y) is visualized as a heatmap. The function ensures\n",
    "    symmetry (W(x,y) = W(y,x)) and sets the diagonal W(x,x) = 0.\n",
    "    The colorbar is fixed between 0 and 1.\n",
    "\n",
    "    Args:\n",
    "        model (torch.nn.Module): The neural network model that takes a (N, 2) tensor\n",
    "                                 of (x,y) coordinates and outputs a (N, 1) tensor\n",
    "                                 of graphon values.\n",
    "        resolution (int, optional): The resolution of the grid for x and y.\n",
    "                                    Defaults to 100.\n",
    "        device (str, optional): The device to run the model on ('cpu', 'cuda', 'cuda:0', etc.).\n",
    "                                If None, it tries to use the device of the model's parameters\n",
    "                                or defaults to 'cpu'.\n",
    "        title (str, optional): The title for the plot. Defaults to\n",
    "                               \"Neural Network Graphon Visualization\".\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        try:\n",
    "            # Attempt to infer device from model parameters\n",
    "            device = next(model.parameters()).device\n",
    "        except StopIteration:\n",
    "            # No parameters, or model is not a typical nn.Module, default to CPU\n",
    "            device = torch.device('cpu')\n",
    "            logger.info(\"Could not infer device from model, defaulting to CPU.\")\n",
    "        except AttributeError:\n",
    "            # Model might not have 'parameters' attribute (e.g. if it's a function)\n",
    "            device = torch.device('cpu')\n",
    "            logger.info(\"Model does not have parameters, defaulting to CPU.\")\n",
    "\n",
    "\n",
    "    model.eval()  # Set the model to evaluation mode\n",
    "    model.to(device)\n",
    "\n",
    "    # 1. Generate grid points\n",
    "    x_coords = np.linspace(0, 1, resolution)\n",
    "    y_coords = np.linspace(0, 1, resolution)\n",
    "    X, Y = np.meshgrid(x_coords, y_coords)\n",
    "\n",
    "    # 2. Prepare input for the neural network\n",
    "    # Shape: (resolution*resolution, 2)\n",
    "    grid_points = np.stack((X.flatten(), Y.flatten()), axis=1)\n",
    "    grid_points_tensor = torch.tensor(grid_points, dtype=torch.float32).to(device)\n",
    "\n",
    "    # 3. Get model output\n",
    "    with torch.no_grad():\n",
    "        W_net_flat = model(grid_points_tensor)\n",
    "\n",
    "    # Reshape to (resolution, resolution)\n",
    "    W_net = W_net_flat.cpu().numpy().reshape(resolution, resolution)\n",
    "\n",
    "    # 4. Ensure symmetry: W_sym(x,y) = (W_net(x,y) + W_net(y,x)) / 2\n",
    "    # This handles cases where the network itself might not be perfectly symmetric.\n",
    "    W_sym = (W_net + W_net.T) / 2.0\n",
    "\n",
    "    # 5. Set diagonal to 0 (typically, graphons don't have self-loops W(x,x)=0)\n",
    "    np.fill_diagonal(W_sym, 0)\n",
    "\n",
    "    # 6. Clip values to be within [0, 1] as probabilities,\n",
    "    #    though the fixed colorbar vmin/vmax will also handle this visually.\n",
    "    W_sym = np.clip(W_sym, 0, 1)\n",
    "\n",
    "    # 7. Plot the heatmap\n",
    "    plt.figure(figsize=(8, 6.5))\n",
    "    # Using origin='lower' makes (0,0) at the bottom-left corner, like typical plots.\n",
    "    #imshow_obj = plt.imshow(W_sym, extent=[0, 1, 0, 1], origin='lower', cmap='viridis', vmin=0, vmax=1)\n",
    "    \n",
    "    # Using origin='upper' makes (0,0) at the top-left corner, like a matrix.\n",
    "    # This was implied by how Z_sym was indexed in the original moment_graphon (e.g. Z_sym[i_upper]).\n",
    "    # Let's use 'upper' to be consistent with matrix indexing if that's the expectation.\n",
    "    # If the user expects (0,0) at bottom-left, then use origin='lower' and potentially flip Y coordinates\n",
    "    # during input generation or flip W_sym before plotting using W_sym = np.flipud(W_sym).\n",
    "    # For now, 'upper' with imshow means row 0 is at the top.\n",
    "    imshow_obj = plt.imshow(W_sym, extent=[0, 1, 1, 0], origin='upper', cmap='viridis', vmin=0, vmax=1)\n",
    "\n",
    "\n",
    "    plt.title(title)\n",
    "    plt.xlabel(\"x\")\n",
    "    plt.ylabel(\"y\")\n",
    "\n",
    "    # Add colorbar fixed between 0 and 1\n",
    "    cbar = plt.colorbar(imshow_obj, fraction=0.046, pad=0.04)\n",
    "    cbar.set_label(\"W(x,y) - Edge Probability\")\n",
    "    cbar.set_ticks(np.linspace(0, 1, 6)) # Example: 0, 0.2, 0.4, 0.6, 0.8, 1.0\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "def train_Moment(moments, weights, model_name):\n",
    "    \"\"\"\n",
    "    Train the Moment model using the dataset and return GW loss, centrality NMSE averages, and standard deviations.\n",
    "\n",
    "    Args:\n",
    "        dataset_name (str): Name of the dataset file in the dataset folder.\n",
    "        graphon_idx (int): Index of the graphon.\n",
    "\n",
    "    Returns:\n",
    "        tuple: Avg GW loss, Std GW loss, Avg and Std NMSE for each centrality measure.\n",
    "    \"\"\"\n",
    "\n",
    "    # Default parameters\n",
    "    epochs = 2000\n",
    "    #patience = 600\n",
    "    #lr = 1e-3\n",
    "    #N = 30000\n",
    "    #hid_dim = 64\n",
    "    num_motifs = 9\n",
    "    weight_mode = 0\n",
    "    #lr = 9.26e-05\n",
    "    lr = 1e-4\n",
    "    N = 20000\n",
    "    hid_dim = 5*[96]\n",
    "    num_layers = 5\n",
    "    \n",
    "    patience = 1000\n",
    "    w0 = 9.\n",
    "\n",
    "  \n",
    "\n",
    "    Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],\n",
    "                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],\n",
    "      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,\n",
    "       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]\n",
    "       ]]\n",
    "\n",
    "    induced_list = motifs_to_induced_motifs(Es)\n",
    "    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    real_moments = torch.tensor(moments).to(device)\n",
    "\n",
    "    #real_moments[1] = (real_moments[0]**2)*(1-real_moments[0])\n",
    "    #eal_moments[2] = (real_moments[0]**3)\n",
    "    # print density counting is finished\n",
    "    #print(\"Density counting is finished\")\n",
    "    #print(real_moments)\n",
    "    while True:\n",
    "        \n",
    "        if model_name == 'SIREN':\n",
    "\n",
    "            lr = 9.26e-05\n",
    "            N = 30000\n",
    "            hid_dim = 6 * [96]\n",
    "            num_layers = 6\n",
    "            weight_mode = 0\n",
    "            patience = 700\n",
    "            w0 = 9.0\n",
    "            model = SirenNet(2, hid_dim, 1, num_layers=num_layers, w0=w0, w0_initial=30.).train().to(device)\n",
    "        elif model_name == 'MLP':\n",
    "            lr = 1e-4\n",
    "            hid_dim = 64\n",
    "            model = lipmlp([2, hid_dim, 1]).train().to(device)\n",
    "        losses = train_momentnet(model, induced_list[2:num_motifs], real_moments[2:num_motifs], 4, N, epochs, patience, lr, device, 1)\n",
    "\n",
    "        try:\n",
    "            if (losses[0] - losses[-1])/losses[0] > 1e-3:\n",
    "                break\n",
    "            else:\n",
    "                print((losses[0] - losses[-1])/losses[0])\n",
    "        except:\n",
    "            \n",
    "            continue\n",
    "\n",
    "\n",
    "    #visualize_nn_graphon(model)\n",
    "    return model\n",
    "\n",
    "\n",
    "# Set up logging\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.DEBUG)\n",
    "formatter = logging.Formatter('%(asctime)s - %(levelname)s: - %(message)s', datefmt='%Y-%m-%d')\n",
    "if not logger.handlers:\n",
    "    # Optionally, log to the screen\n",
    "    ch = logging.StreamHandler()\n",
    "    ch.setLevel(logging.DEBUG)\n",
    "    ch.setFormatter(formatter)\n",
    "    logger.addHandler(ch)\n",
    "logging.getLogger('matplotlib').setLevel(logging.WARNING)\n",
    "\n",
    "\n",
    "########################################\n",
    "# Helper Function Definitions\n",
    "########################################\n",
    "\n",
    "def prepare_dataset_x(dataset):\n",
    "    logger.info(\"Prepare dataset x\")\n",
    "    if dataset[0].x is None:\n",
    "        logger.info(\"dataset[0].x is None\")\n",
    "        max_degree = 0\n",
    "        degs_all = []\n",
    "        for data in dataset:\n",
    "            d = degree(data.edge_index[0], dtype=torch.long)\n",
    "            degs_all.append(d)\n",
    "            max_degree = max(max_degree, d.max().item())\n",
    "            data.num_nodes = int(torch.max(data.edge_index)) + 1\n",
    "\n",
    "        if max_degree < 2000:\n",
    "            for data in dataset:\n",
    "                degs = degree(data.edge_index[0], dtype=torch.long)\n",
    "                data.x = F.one_hot(degs, num_classes=max_degree+1).to(torch.float)\n",
    "        else:\n",
    "            deg = torch.cat(degs_all, dim=0).to(torch.float)\n",
    "            mean, std = deg.mean().item(), deg.std().item()\n",
    "            for data in dataset:\n",
    "                degs = degree(data.edge_index[0], dtype=torch.long)\n",
    "                data.x = ((degs - mean) / std).view(-1, 1)\n",
    "    return dataset\n",
    "\n",
    "\n",
    "def prepare_dataset_onehot_y(dataset):\n",
    "    y_set = set()\n",
    "    for data in dataset:\n",
    "        y_set.add(int(data.y))\n",
    "    num_classes = len(y_set)\n",
    "    for data in dataset:\n",
    "        # Each data.y becomes a one-hot vector; here we take the first element since \n",
    "        # F.one_hot returns a vector wrapped in an extra dimension.\n",
    "        data.y = F.one_hot(data.y, num_classes=num_classes).to(torch.float)[0]\n",
    "    return dataset\n",
    "\n",
    "\n",
    "def mixup_cross_entropy_loss(input, target, size_average=True):\n",
    "    \"\"\"Mixup version of cross entropy loss.\"\"\"\n",
    "    assert input.size() == target.size()\n",
    "    assert isinstance(input, Variable) and isinstance(target, Variable)\n",
    "    loss = - torch.sum(input * target)\n",
    "    return loss / input.size()[0] if size_average else loss\n",
    "\n",
    "\n",
    "def train(model, train_loader, optimizer, device, num_classes):\n",
    "    model.train()\n",
    "    loss_all = 0\n",
    "    graph_all = 0\n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data.x, data.edge_index, data.batch)\n",
    "        y = data.y.view(-1, num_classes)\n",
    "        loss = mixup_cross_entropy_loss(output, y)\n",
    "        loss.backward()\n",
    "        loss_all += loss.item() * data.num_graphs\n",
    "        #print(f\"loss : {loss_all}\")\n",
    "        graph_all += data.num_graphs\n",
    "        optimizer.step()\n",
    "    #print(f\"Output : {output}\")\n",
    "    return loss_all / graph_all\n",
    "\n",
    "\n",
    "def test(model, loader, device, num_classes):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    loss_all = 0\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        output = model(data.x, data.edge_index, data.batch)\n",
    "        pred = output.max(dim=1)[1]\n",
    "        y = data.y.view(-1, num_classes)\n",
    "        loss_all += mixup_cross_entropy_loss(output, y).item() * data.num_graphs\n",
    "        \n",
    "        # Convert one-hot to class indices for comparison\n",
    "        y_labels = y.max(dim=1)[1]\n",
    "        correct += pred.eq(y_labels).sum().item()\n",
    "        total += data.num_graphs\n",
    "    return correct / total, loss_all / total\n",
    "\n",
    "\n",
    "\n",
     "def two_moments_mixup(two_moments, la=0.5, num_sample=20, ge='ISGL', resolution=None):\n",
     "    \"\"\"Generate mixed-label synthetic graphs from a convex combination of two classes' moments.\n",
     "\n",
     "    The label vectors and moment vectors of the two classes are mixed as\n",
     "    la * first + (1 - la) * second. A graphon network is fitted to the mixed moments\n",
     "    (train_Moment with 'MLP'), evaluated on a grid (moment_graphon), and graphs are\n",
     "    sampled from it with simulate_graphs.\n",
     "\n",
     "    Args:\n",
     "        two_moments: pair (a, b); a[0] / b[0] are label vectors (numpy arrays, since the\n",
     "            mix goes through torch.from_numpy) and a[1] / b[1] the moment vectors.\n",
     "        la: mixing weight of the first class.\n",
     "        num_sample: number of graphs to generate.\n",
     "        ge: unused in this function.\n",
     "        resolution: grid resolution passed to moment_graphon. NOTE(review): the default\n",
     "            None presumably fails there (np.linspace needs an int) -- callers must pass it.\n",
     "\n",
     "    Returns:\n",
     "        list of PyG Data objects (y, edge_index, num_nodes), one per sample, or [] if\n",
     "        any single sample yields an empty graph 21 times in a row (graphs sampled\n",
     "        before that failure are discarded).\n",
     "    \"\"\"\n",
     "\n",
     "    label = la * two_moments[0][0] + (1 - la) * two_moments[1][0]\n",
     "    # NOTE(review): this single tensor object is shared as .y by every returned graph.\n",
     "    sample_graph_label = torch.from_numpy(label).type(torch.float32)\n",
     "\n",
     "    \n",
     "    # Res = [int(resolution[0]) for _ in range(num_sample)]\n",
     "    #print(Res)\n",
     "\n",
     "    #print(two_moments[0][1])\n",
     "    #print(two_moments[1][1])\n",
     "    new_moment = la * two_moments[0][1] + (1 - la) * two_moments[1][1]\n",
     "    \n",
     "    #sample_graph = (np.random.rand(*new_graphon.shape) <= new_graphon).astype(np.int32)\n",
     "\n",
     "    # Fit a graphon network to the mixed moments and evaluate it on a resolution x resolution grid.\n",
     "    trained_model = train_Moment(new_moment, 0, \"MLP\")\n",
     "    graphon = moment_graphon(resolution, trained_model)\n",
     "\n",
     "\n",
     "\n",
     "\n",
     "\n",
     "    sample_graphs = []\n",
     "    for i in range(num_sample):\n",
     "        \n",
     "        \n",
     "        # `done` counts failed attempts (empty graph) for the current sample.\n",
     "        done = 0\n",
     "        while True:\n",
     "            if done > 20:\n",
     "                return []\n",
     "            # Graph size drawn from the global `random` RNG, 30..60 inclusive.\n",
     "            num_nodes = random.randint(30, 60)\n",
     "            # NOTE(review): seed_gsize / seed_edge are fixed at 19 for every draw; if\n",
     "            # simulate_graphs reseeds from them, samples with the same num_nodes are\n",
     "            # identical -- confirm. simulate_graphs and Data are not imported explicitly;\n",
     "            # presumably they come from `from SIGL.tools import *`.\n",
     "            sample_graph = simulate_graphs(graphon, seed_gsize =19, seed_edge =19, num_graphs = 1, num_nodes = num_nodes, graph_size = 'fixed', offset =0)[0]\n",
     "\n",
     "            # print number of nodes\n",
     "            #print(f\"Number of nodes: {num_nodes}\")\n",
     "\n",
     "            #sample_graph = np.triu(sample_graph)\n",
     "            #sample_graph = sample_graph + sample_graph.T - np.diag(np.diag(sample_graph))\n",
     "            # Drop isolated nodes: all-zero rows, then all-zero columns.\n",
     "            sample_graph = sample_graph[sample_graph.sum(axis=1) != 0]\n",
     "            sample_graph = sample_graph[:, sample_graph.sum(axis=0) != 0]\n",
     "\n",
     "            A = torch.from_numpy(sample_graph)\n",
     "            edge_index, _ = dense_to_sparse(A)\n",
     "            num_nodes = sample_graph.shape[0]\n",
     "\n",
     "            if num_nodes == 0:\n",
     "                print('num_nodes is 0')\n",
     "                done += 1\n",
     "                continue\n",
     "\n",
     "            # continue if the graph is empty in terms of edges\n",
     "            if edge_index.shape[1] == 0:\n",
     "                print('edge_index is 0')\n",
     "                done += 1\n",
     "                continue\n",
     "\n",
     "            break\n",
     "\n",
     "\n",
     "        pyg_graph = Data()\n",
     "        pyg_graph.y = sample_graph_label\n",
     "        pyg_graph.edge_index = edge_index\n",
     "        pyg_graph.num_nodes = num_nodes\n",
     "        sample_graphs.append(pyg_graph)\n",
     "        \n",
     "    return sample_graphs\n",
    "\n",
    "########################################\n",
    "# Default Parameters and Seed List Setup\n",
    "########################################\n",
    "\n",
     "# Default parameters\n",
     "data_path      = \"./\"  # TUDataset root: data is read from osp.join(data_path, dataset_name)\n",
     "dataset_name   = \"IMDB-MULTI\"  # TUDataset name\n",
     "model_name     = \"GIN\"\n",
     "num_epochs     = 400\n",
     "batch_size     = 128\n",
     "learning_rate  = 0.001\n",
     "num_hidden     = 64\n",
     "lam_range      = [0.1, 0.4]  # mixup weights are drawn with np.random.uniform(low, high)\n",
     "# Augmentation budget: aug_num mixup weights are drawn; for each one the main loop uses\n",
     "# num_sample = int(train_nums * aug_ratio / aug_num).\n",
     "aug_ratio      = 0.2\n",
     "aug_num        = 10\n",
     "ge             = \"Moment\"   # options: \"ISGL\", \"Moment\", \"IGNR\", etc.\n",
     "log_screen     = True\n",
     "gmixup         = True  # enables the mixup augmentation branch of the main loop\n",
     "n_epochs_inr   = 20\n",
     "\n",
     "# Motif edge lists grouped by size (2, 3, 4 and 5 nodes).\n",
     "# NOTE(review): identical copies live in train_Moment and inside the main loop; keep them in sync.\n",
     "Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],\n",
     "                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],\n",
     "      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],\n",
     "       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,\n",
     "       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],\n",
     "       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],\n",
     "       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]\n",
     "       ]]\n",
     "\n",
     "induced_list = motifs_to_induced_motifs(Es)\n",
     "\n",
     "# (The seed list itself is defined right before the main loop.)\n",
     "\n",
     "# To store metrics across seeds\n",
     "all_best_test_acc = []\n",
     "all_last10_avg_acc = []\n",
    "\n",
    "def plot_moment_vectors(moment_matrix, title=\"Moment Vectors\"):\n",
    "    \"\"\"\n",
    "    Plots moment vectors from a list of [label, moments_vector] pairs.\n",
    "\n",
    "    Args:\n",
    "        moment_matrix (list): A list where each element is a list or tuple\n",
    "                              of the form [label (str), real_moments (list or np.array)].\n",
    "        title (str): The title for the plot.\n",
    "    \"\"\"\n",
    "    if not moment_matrix:\n",
    "        print(\"The moment_matrix is empty. Nothing to plot.\")\n",
    "        return\n",
    "\n",
    "    plt.figure(figsize=(10, 6)) # Adjust figure size as needed\n",
    "\n",
    "    for item in moment_matrix:\n",
    "        if len(item) != 2:\n",
    "            print(f\"Skipping invalid item: {item}. Expected [label, real_moments].\")\n",
    "            continue\n",
    "\n",
    "        label, real_moments = item\n",
    "        \n",
    "        # Ensure real_moments is a numpy array for easier handling\n",
    "        moments_vector = np.array(real_moments)\n",
    "\n",
    "        if moments_vector.ndim != 1:\n",
    "            print(f\"Skipping label '{label}': moments_vector is not 1-dimensional.\")\n",
    "            continue\n",
    "        \n",
    "        if moments_vector.size == 0:\n",
    "            print(f\"Skipping label '{label}': moments_vector is empty.\")\n",
    "            continue\n",
    "\n",
    "        # Create an x-axis based on the index of the moments\n",
    "        x_values = np.arange(len(moments_vector))\n",
    "        \n",
    "        plt.plot(x_values, moments_vector, marker='o', linestyle='-', label=str(label))\n",
    "\n",
    "    plt.title(title)\n",
    "    plt.xlabel(\"Moment Index\")\n",
    "    plt.ylabel(\"Moment Value\")\n",
    "    \n",
    "    # Add a legend if there are any lines plotted\n",
    "    # Get current handles and labels\n",
    "    handles, labels = plt.gca().get_legend_handles_labels()\n",
    "    if handles: # Check if any lines were actually plotted\n",
    "        plt.legend()\n",
    "    else:\n",
    "        print(\"No valid data was plotted, so no legend will be shown.\")\n",
    "        \n",
    "    plt.grid(True)\n",
    "    plt.tight_layout() # Adjusts plot to prevent labels from overlapping\n",
    "    plt.show()\n",
    "\n",
    "########################################\n",
    "# Main Loop: Iterate over Seed List\n",
    "########################################\n",
     "# Outer-loop seeds: each run reseeds `random` and NumPy with it, which drives that run's\n",
     "# dataset shuffle / train-test split and the sampled mixup weights.\n",
     "seeds = [61314, 1314, 11314, 21314, 31314, 41314, 51314, 71314]\n",
    "\n",
    "for seed in seeds:\n",
    "    for seed_torch in [17]:\n",
    "        logger.info(\"==========================================\")\n",
    "        logger.info(f\"Starting training for seed {seed}\")\n",
    "\n",
    "        if seed == 61314:\n",
    "            seed_torch = 4\n",
    "\n",
    "        # Set seeds for reproducibility\n",
    "        #torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        random.seed(seed)\n",
    "\n",
    "        device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
    "        logger.info(f\"Running on device: {device}\")\n",
    "\n",
    "        # Load dataset\n",
    "        path = osp.join(data_path, dataset_name)\n",
    "        dataset = TUDataset(path, name=dataset_name)\n",
    "        dataset = list(dataset)\n",
    "        \n",
    "        # Reshape labels\n",
    "        for graph in dataset:\n",
    "            graph.y = graph.y.view(-1)\n",
    "        dataset = prepare_dataset_onehot_y(dataset)\n",
    "\n",
    "        # Log global dataset statistics\n",
    "        avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset)\n",
    "        logger.info(f\"Dataset {dataset_name}: {len(dataset)} graphs, avg nodes: {avg_num_nodes}\")\n",
    "\n",
    "        # Shuffle and split dataset: 70% train, 10% validation (from train), 20% test\n",
    "        random.shuffle(dataset)\n",
    "        train_nums = int(len(dataset) * 0.7)\n",
    "        train_val_nums = int(len(dataset) * 0.8)\n",
    "        \n",
    "        # Log training subset statistics\n",
    "        avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset[:train_nums])\n",
    "        logger.info(f\"Training graphs: avg nodes: {avg_num_nodes}, max nodes: {max_num_nodes}\")\n",
    "\n",
    "        resolution = 10*int(median_num_nodes)\n",
    "        logger.info(f\"Resolution set to: {resolution}\")\n",
    "        \n",
    "        # Augment dataset using graphon mixup if enabled\n",
    "        if gmixup:\n",
    "            class_graphs = split_class_graphs(dataset[:train_nums])\n",
    "            test_class_graphs = split_class_graphs(dataset[train_val_nums:])\n",
    "            ###########################################\n",
    "            def compute_moment_scores(moment_matrix):\n",
    "                \"\"\"\n",
    "                Computes discrimination scores for each moment (column) in a matrix.\n",
    "                \n",
    "                Parameters:\n",
    "                    moment_matrix (np.ndarray): A 2D NumPy array of shape (n_classes, m_moments)\n",
    "                                                where each row corresponds to a class.\n",
    "                \n",
    "                Returns:\n",
    "                    std_scores (np.ndarray): Standard deviation scores for each moment.\n",
    "                    range_scores (np.ndarray): Range (max-min) scores for each moment.\n",
    "                    avg_pairwise (np.ndarray): Average pairwise difference scores for each moment.\n",
    "                \"\"\"\n",
    "                # Method 1: Standard deviation across classes for each moment.\n",
    "                std_scores = np.std(moment_matrix, axis=0)\n",
    "                \n",
    "                # Method 2: Range (maximum minus minimum) for each moment.\n",
    "                range_scores = np.ptp(moment_matrix, axis=0)  # np.ptp returns the range (peak to peak)\n",
    "                \n",
    "                # Method 3: Average pairwise absolute differences between classes for each moment.\n",
    "                n, m = moment_matrix.shape\n",
    "                avg_pairwise = np.zeros(m)\n",
    "                for j in range(m):\n",
    "                    pairwise_diffs = []\n",
    "                    for i in range(n):\n",
    "                        for k in range(i + 1, n):\n",
    "                            diff = abs(moment_matrix[i, j] - moment_matrix[k, j])\n",
    "                            pairwise_diffs.append(diff)\n",
    "                    avg_pairwise[j] = np.mean(pairwise_diffs)\n",
    "                    \n",
    "                return std_scores, range_scores, avg_pairwise\n",
    "\n",
    "\n",
    "\n",
    "            # Motif edge lists grouped by number of nodes (2, 3, 4 and 5 nodes).\n",
    "            Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],\n",
    "                                                                            [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],\n",
    "                [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],\n",
    "                [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,\n",
    "                [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],\n",
    "                [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],\n",
    "                [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]\n",
    "                ]]\n",
    "\n",
    "            induced_list = motifs_to_induced_motifs(Es)\n",
    "            \n",
    "            # Re-seed per run so augmentation is reproducible for this seed.\n",
    "            torch.manual_seed(seed_torch)\n",
    "            torch.cuda.manual_seed(seed_torch)\n",
    "            torch.cuda.manual_seed_all(seed_torch)\n",
    "            torch.backends.cudnn.deterministic = True\n",
    "            torch.backends.cudnn.benchmark = False\n",
    "            os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "            torch.use_deterministic_algorithms(True)\n",
    "            num_sample = int(train_nums * aug_ratio / aug_num)\n",
    "            lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(aug_num,))\n",
    "            new_graph = []\n",
    "            for lam in lam_list:\n",
    "                logger.info(f\"lam: {lam}\")\n",
    "                logger.info(f\"num_sample: {num_sample}\")\n",
    "\n",
    "                class_graphs = split_class_graphs(dataset[:train_nums])\n",
    "                # pick up to 10 graphs from each class\n",
    "                class_graphs = [(label, random.sample(graphs, min(10, len(graphs)))) for label, graphs in class_graphs]\n",
    "                moment_matrix = []\n",
    "                for label, graphs in class_graphs:\n",
    "                    estimated_densities = np.zeros(9)\n",
    "                    num_used = 0  # number of graphs that actually contributed to the sum\n",
    "                    nx_graphs = [nx.from_numpy_array(graph) for graph in graphs]\n",
    "                    for graph in nx_graphs:\n",
    "                        # only graphs with at least 6 nodes are counted\n",
    "                        if graph.number_of_nodes() < 6:\n",
    "                            continue\n",
    "                        node_orbit_counts = orca(graph)\n",
    "                        density = count2density(node_orbit_counts, graph.number_of_nodes())\n",
    "                        estimated_densities += density\n",
    "                        num_used += 1\n",
    "                    # Average over the graphs actually counted (not len(graphs)) so skipped\n",
    "                    # small graphs do not bias the moments toward zero; max(..., 1) keeps the\n",
    "                    # all-skipped case at zeros instead of dividing by zero.\n",
    "                    real_moments = estimated_densities / max(num_used, 1)\n",
    "                    moment_matrix.append([label, real_moments])\n",
    "\n",
    "                # Mix the two classes with the smallest and largest first density entry.\n",
    "                all_moments = sorted(moment_matrix, key=lambda x: x[1][0])\n",
    "                two_moments = [all_moments[0], all_moments[-1]]\n",
    "                new_graph += two_moments_mixup(two_moments, la=lam, num_sample=num_sample, ge=\"USVT\", resolution=resolution)\n",
    "\n",
    "            avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(new_graph)\n",
    "            logger.info(f\"New graphs: avg nodes: {avg_num_nodes}, min nodes: {min_num_nodes}, max nodes: {max_num_nodes}\")\n",
    "            logger.info(f\"Real augmentation ratio: {len(new_graph)/train_nums}\")\n",
    "            # Prepend synthetic graphs so they fall inside the training slice.\n",
    "            dataset = new_graph + dataset\n",
    "            train_nums += len(new_graph)\n",
    "            train_val_nums += len(new_graph)\n",
    "\n",
    "        # Prepare node features for whole dataset\n",
    "        dataset = prepare_dataset_x(dataset)\n",
    "        logger.info(f\"Dataset feature shape: {dataset[0].x.shape}\")\n",
    "        logger.info(f\"Dataset label shape: {dataset[0].y.shape}\")\n",
    "\n",
    "        num_features = dataset[0].x.shape[1]\n",
    "        num_classes = dataset[0].y.shape[0]\n",
    "\n",
    "        # Ensure every graph carries an edge_attr tensor (all ones, 3 columns) when missing.\n",
    "        for data in dataset:\n",
    "            if not hasattr(data, 'edge_attr') or data.edge_attr is None:\n",
    "                data.edge_attr = torch.ones((data.edge_index.size(1), 3)) \n",
    "\n",
    "        # Split dataset into train, validation and test\n",
    "        train_dataset = dataset[:train_nums]\n",
    "        random.shuffle(train_dataset)\n",
    "        val_dataset = dataset[train_nums:train_val_nums]\n",
    "        test_dataset = dataset[train_val_nums:]\n",
    "        \n",
    "        logger.info(f\"Train dataset size: {len(train_dataset)}\")\n",
    "        logger.info(f\"Validation dataset size: {len(val_dataset)}\")\n",
    "        logger.info(f\"Test dataset size: {len(test_dataset)}\")\n",
    "\n",
    "        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "        val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
    "        test_loader = DataLoader(test_dataset, batch_size=batch_size)\n",
    "\n",
    "        # Instantiate the model\n",
    "        if model_name == \"GIN\":\n",
    "            model_instance = GIN(num_features=num_features, num_classes=num_classes, num_hidden=num_hidden).to(device)\n",
    "        else:\n",
    "            logger.info(\"No valid model specified.\")\n",
    "            continue\n",
    "\n",
    "        optimizer = schedulefree.AdamWScheduleFree(model_instance.parameters(), lr = learning_rate)\n",
    "        optimizer.train()\n",
    "        # NOTE(review): schedule-free optimizers are designed to run without an LR schedule;\n",
    "        # StepLR is kept to preserve the existing setup — confirm it is intended.\n",
    "        scheduler = StepLR(optimizer, step_size=100, gamma=0.5)\n",
    "\n",
    "        # Train the model\n",
    "        last_10_epoch_acc = []  # test acc of the last 10 epochs\n",
    "        best_val_acc = 0\n",
    "        best_test_acc = 0\n",
    "        best_epoch = None  # stays None if no epoch > 10 is ever run\n",
    "\n",
    "        for epoch in range(1, num_epochs):\n",
    "            # Schedule-free optimizers evaluate at a different parameter point than they\n",
    "            # train at, so switch modes explicitly around training and evaluation.\n",
    "            # NOTE(review): if GIN uses BatchNorm, the schedulefree docs also recommend\n",
    "            # recomputing BN statistics in eval mode — confirm.\n",
    "            optimizer.train()\n",
    "            train_loss = train(model_instance, train_loader, optimizer, device, num_classes)\n",
    "\n",
    "            optimizer.eval()\n",
    "            val_acc, val_loss = test(model_instance, val_loader, device, num_classes)\n",
    "            test_acc, test_loss = test(model_instance, test_loader, device, num_classes)\n",
    "            optimizer.train()\n",
    "            scheduler.step()\n",
    "\n",
    "            # Model selection on validation accuracy, ignoring the first 10 epochs.\n",
    "            if val_acc >= best_val_acc and epoch > 10:\n",
    "                best_val_acc = val_acc\n",
    "                best_test_acc = test_acc\n",
    "                best_epoch = epoch\n",
    "\n",
    "            # range(1, num_epochs) stops at num_epochs - 1, so >= keeps exactly the last 10 epochs.\n",
    "            if epoch >= num_epochs - 10:\n",
    "                last_10_epoch_acc.append(test_acc)\n",
    "\n",
    "            if epoch % 30 == 0:\n",
    "                print('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f}, Val Acc: {:.6f}, Test Acc: {:.6f}'.format(\n",
    "                    epoch, train_loss, val_loss, test_loss, val_acc, test_acc))\n",
    "        \n",
    "        avg_last10 = np.mean(last_10_epoch_acc)\n",
    "\n",
    "        print(\"####################################################\")\n",
    "        print(f\"Seed {seed_torch}: Last 10 epochs average test acc: {np.round(avg_last10,3)}\")\n",
    "        print(f\"Seed {seed_torch}: Best test acc (based on validation) = {best_test_acc} at epoch {best_epoch}\")\n",
    "        print(\"####################################################\")\n",
    "\n",
    "    all_best_test_acc.append(best_test_acc)\n",
    "    all_last10_avg_acc.append(avg_last10)\n",
    "\n",
    "# Aggregate per-seed results: mean and standard deviation across seeds.\n",
    "print(\"==========================================\")\n",
    "print(\"Overall Results Across Seeds:\")\n",
    "print(f\"Best Test Accuracy: Mean = {np.mean(all_best_test_acc):.3f}, Std = {np.std(all_best_test_acc):.3f}\")\n",
    "print(f\"Last 10 Epoch Test Accuracy Average: Mean = {np.mean(all_last10_avg_acc):.3f}, Std = {np.std(all_last10_avg_acc):.3f}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f9728c3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
