{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "375c02db",
   "metadata": {},
   "outputs": [],
   "source": [
    "########################################\n",
    "# Single Cell Code for Multi-Seed Training\n",
    "########################################\n",
    "\n",
    "import sys\n",
    "import os \n",
    "sys.path.append('../../')\n",
    "sys.path.append('../')\n",
    "from time import time\n",
    "import logging\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch_geometric.utils import degree\n",
    "from torch.autograd import Variable\n",
    "import random\n",
    "from torch.optim.lr_scheduler import StepLR\n",
    "\n",
    "# Import your utilities and models\n",
    "from gmixupUtils import stat_graph, split_class_graphs, align_graphs\n",
    "from gmixupUtils import two_graphons_mixup, universal_svd, get_graphon\n",
    "from gmixupUtils import GIN\n",
    "from Moment.tools import motifs_to_induced_motifs, orca, count2density\n",
    "from SIGL.tools import *\n",
    "from Moment.trainMoment import train_Moment\n",
    "import networkx as nx\n",
    "\n",
    "# Set up logging\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.DEBUG)\n",
    "formatter = logging.Formatter('%(asctime)s - %(levelname)s: - %(message)s', datefmt='%Y-%m-%d')\n",
    "if not logger.handlers:\n",
    "    # Optionally, log to the screen\n",
    "    ch = logging.StreamHandler()\n",
    "    ch.setLevel(logging.DEBUG)\n",
    "    ch.setFormatter(formatter)\n",
    "    logger.addHandler(ch)\n",
    "logging.getLogger('matplotlib').setLevel(logging.WARNING)\n",
    "\n",
    "\n",
    "########################################\n",
    "# Helper Function Definitions\n",
    "########################################\n",
    "\n",
    "def prepare_dataset_x(dataset):\n",
    "    logger.info(\"Prepare dataset x\")\n",
    "    if dataset[0].x is None:\n",
    "        logger.info(\"dataset[0].x is None\")\n",
    "        max_degree = 0\n",
    "        degs_all = []\n",
    "        for data in dataset:\n",
    "            d = degree(data.edge_index[0], dtype=torch.long)\n",
    "            degs_all.append(d)\n",
    "            max_degree = max(max_degree, d.max().item())\n",
    "            data.num_nodes = int(torch.max(data.edge_index)) + 1\n",
    "\n",
    "        if max_degree < 2000:\n",
    "            for data in dataset:\n",
    "                degs = degree(data.edge_index[0], dtype=torch.long)\n",
    "                data.x = F.one_hot(degs, num_classes=max_degree+1).to(torch.float)\n",
    "        else:\n",
    "            deg = torch.cat(degs_all, dim=0).to(torch.float)\n",
    "            mean, std = deg.mean().item(), deg.std().item()\n",
    "            for data in dataset:\n",
    "                degs = degree(data.edge_index[0], dtype=torch.long)\n",
    "                data.x = ((degs - mean) / std).view(-1, 1)\n",
    "    return dataset\n",
    "\n",
    "\n",
    "def prepare_dataset_onehot_y(dataset):\n",
    "    y_set = set()\n",
    "    for data in dataset:\n",
    "        y_set.add(int(data.y))\n",
    "    num_classes = len(y_set)\n",
    "    for data in dataset:\n",
    "        # Each data.y becomes a one-hot vector; here we take the first element since \n",
    "        # F.one_hot returns a vector wrapped in an extra dimension.\n",
    "        data.y = F.one_hot(data.y, num_classes=num_classes).to(torch.float)[0]\n",
    "    return dataset\n",
    "\n",
    "\n",
    "def mixup_cross_entropy_loss(input, target, size_average=True):\n",
    "    \"\"\"Mixup version of cross entropy loss.\"\"\"\n",
    "    assert input.size() == target.size()\n",
    "    assert isinstance(input, Variable) and isinstance(target, Variable)\n",
    "    loss = - torch.sum(input * target)\n",
    "    return loss / input.size()[0] if size_average else loss\n",
    "\n",
    "\n",
    "def train(model, train_loader, optimizer, device, num_classes):\n",
    "    model.train()\n",
    "    loss_all = 0\n",
    "    graph_all = 0\n",
    "    for data in train_loader:\n",
    "        data = data.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data.x, data.edge_index, data.batch)\n",
    "        y = data.y.view(-1, num_classes)\n",
    "        loss = mixup_cross_entropy_loss(output, y)\n",
    "        loss.backward()\n",
    "        loss_all += loss.item() * data.num_graphs\n",
    "        graph_all += data.num_graphs\n",
    "        optimizer.step()\n",
    "    return loss_all / graph_all\n",
    "\n",
    "\n",
    "def test(model, loader, device, num_classes):\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    loss_all = 0\n",
    "    for data in loader:\n",
    "        data = data.to(device)\n",
    "        output = model(data.x, data.edge_index, data.batch)\n",
    "        pred = output.max(dim=1)[1]\n",
    "        y = data.y.view(-1, num_classes)\n",
    "        loss_all += mixup_cross_entropy_loss(output, y).item() * data.num_graphs\n",
    "        # Convert one-hot to class indices for comparison\n",
    "        y_labels = y.max(dim=1)[1]\n",
    "        correct += pred.eq(y_labels).sum().item()\n",
    "        total += data.num_graphs\n",
    "    return correct / total, loss_all / total\n",
    "\n",
    "\n",
    "########################################\n",
    "# Default Parameters and Seed List Setup\n",
    "########################################\n",
    "\n",
    "# Default parameters\n",
    "data_path      = \"./\"\n",
    "dataset_name   = \"AIDS\"\n",
    "model_name     = \"GIN\"\n",
    "num_epochs     = 400\n",
    "batch_size     = 128\n",
    "learning_rate  = 0.01\n",
    "num_hidden     = 64\n",
    "lam_range      = [0.1, 0.2]\n",
    "aug_ratio      = 0.2\n",
    "aug_num        = 10\n",
    "ge             = \"USVT\"   # options: \"ISGL\", \"Moment\", \"IGNR\", etc.\n",
    "log_screen     = True\n",
    "gmixup         = True\n",
    "n_epochs_inr   = 20\n",
    "\n",
    "Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],\n",
    "                                                                 [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],\n",
    "      [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,\n",
    "       [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],\n",
    "       [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]\n",
    "       ]]\n",
    "\n",
    "induced_list = motifs_to_induced_motifs(Es)\n",
    "\n",
    "# Example seed list\n",
    "seed_list = [1314, 11314, 21314, 31314, 41314, 51314, 61314, 71314]\n",
    "\n",
    "# To store metrics across seeds\n",
    "all_best_test_acc = []\n",
    "all_last10_avg_acc = []\n",
    "\n",
    "########################################\n",
    "# Main Loop: Iterate over Seed List\n",
    "########################################\n",
    "\n",
    "for seed in seed_list:\n",
    "    logger.info(\"==========================================\")\n",
    "    logger.info(f\"Starting training for seed {seed}\")\n",
    "\n",
    "    # Set seeds for reproducibility\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    logger.info(f\"Running on device: {device}\")\n",
    "\n",
    "    # Load dataset\n",
    "    path = osp.join(data_path, dataset_name)\n",
    "    dataset = TUDataset(path, name=dataset_name)\n",
    "    dataset = list(dataset)\n",
    "    \n",
    "    # Reshape labels\n",
    "    for graph in dataset:\n",
    "        graph.y = graph.y.view(-1)\n",
    "    dataset = prepare_dataset_onehot_y(dataset)\n",
    "\n",
    "    # Log global dataset statistics\n",
    "    avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset)\n",
    "    logger.info(f\"Dataset {dataset_name}: {len(dataset)} graphs, avg nodes: {avg_num_nodes}\")\n",
    "\n",
    "    # Shuffle and split dataset: 70% train, 10% validation (from train), 20% test\n",
    "    random.shuffle(dataset)\n",
    "    train_nums = int(len(dataset) * 0.7)\n",
    "    train_val_nums = int(len(dataset) * 0.8)\n",
    "    \n",
    "    # Log training subset statistics\n",
    "    avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(dataset[:train_nums])\n",
    "    logger.info(f\"Training graphs: avg nodes: {avg_num_nodes}, max nodes: {max_num_nodes}\")\n",
    "\n",
    "    resolution = int(avg_num_nodes)\n",
    "    logger.info(f\"Resolution set to: {resolution}\")\n",
    "    \n",
    "    # Augment dataset using graphon mixup if enabled\n",
    "    if gmixup:\n",
    "        class_graphs = split_class_graphs(dataset[:train_nums])\n",
    "\n",
    "        ###########################################\n",
    "        def compute_moment_scores(moment_matrix):\n",
    "            \"\"\"\n",
    "            Computes discrimination scores for each moment (column) in a matrix.\n",
    "            \n",
    "            Parameters:\n",
    "                moment_matrix (np.ndarray): A 2D NumPy array of shape (n_classes, m_moments)\n",
    "                                            where each row corresponds to a class.\n",
    "            \n",
    "            Returns:\n",
    "                std_scores (np.ndarray): Standard deviation scores for each moment.\n",
    "                range_scores (np.ndarray): Range (max-min) scores for each moment.\n",
    "                avg_pairwise (np.ndarray): Average pairwise difference scores for each moment.\n",
    "            \"\"\"\n",
    "            # Method 1: Standard deviation across classes for each moment.\n",
    "            std_scores = np.std(moment_matrix, axis=0)\n",
    "            \n",
    "            # Method 2: Range (maximum minus minimum) for each moment.\n",
    "            range_scores = np.ptp(moment_matrix, axis=0)  # np.ptp returns the range (peak to peak)\n",
    "            \n",
    "            # Method 3: Average pairwise absolute differences between classes for each moment.\n",
    "            n, m = moment_matrix.shape\n",
    "            avg_pairwise = np.zeros(m)\n",
    "            for j in range(m):\n",
    "                pairwise_diffs = []\n",
    "                for i in range(n):\n",
    "                    for k in range(i + 1, n):\n",
    "                        diff = abs(moment_matrix[i, j] - moment_matrix[k, j])\n",
    "                        pairwise_diffs.append(diff)\n",
    "                avg_pairwise[j] = np.mean(pairwise_diffs)\n",
    "                \n",
    "            return std_scores, range_scores, avg_pairwise\n",
    "\n",
    "\n",
    "\n",
    "        Es = [[[(0, 1)]], [[(0, 1), (1, 2)], [(0, 1), (0, 2), (1, 2)]], [[(0, 1), (1, 2), (2, 3)], [(0, 1), (0, 2), (0, 3)], [(0, 1), (0, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2)],\n",
    "                                                                        [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]],\n",
    "            [[(0, 1), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4)], [(0, 1), (0, 2), (1, 2), (0, 3), (3, 4)],\n",
    "            [(0, 1), (0, 2), (0, 3), (0, 4), (3, 4)], [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (0, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (3, 4)] ,\n",
    "            [(0, 1), (0, 2), (0, 3), (2, 4), (3, 4), (2, 3)], [(0, 1), (0, 2), (0, 3), (1, 4), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (1, 4), (2, 4)],\n",
    "            [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (2, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 2), (2, 3), (1, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (1, 3), (1, 2), (2, 3), (2, 4), (3, 4)],\n",
    "            [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (1, 3), (2, 4), (3, 4)], [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)]\n",
    "            ]]\n",
    "\n",
    "        induced_list = motifs_to_induced_motifs(Es)\n",
    "\n",
    "        # Suppose we have data from 3 classes and 4 moments per class:\n",
    "        moment_matrix = []\n",
    "\n",
    "        for label, graphs in class_graphs:\n",
    "            estimated_densities = np.zeros(9)\n",
    "            nx_graphs = [nx.from_numpy_array(graph) for graph in graphs]\n",
    "            for graph in nx_graphs:\n",
    "                node_orbit_counts = orca(graph)\n",
    "                density = count2density(node_orbit_counts, graph.number_of_nodes())\n",
    "                estimated_densities += density\n",
    "            real_moments = estimated_densities / len(graphs)\n",
    "            moment_matrix.append(real_moments)\n",
    "        moment_matrix = np.array(moment_matrix)\n",
    "\n",
    "\n",
    "        std_scores, range_scores, avg_pairwise = compute_moment_scores(moment_matrix)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "        ############################################\n",
    "        graphons = []\n",
    "        start_time = time()\n",
    "        for label, graphs in class_graphs:\n",
    "            logger.info(f\"Label {label}: {len(graphs)} graphs\")\n",
    "            num_estimate = int(aug_ratio * len(graphs))\n",
    "            inr_graphs = random.sample(graphs, num_estimate)\n",
    "            logger.info(f\"Selected {len(inr_graphs)} graphs for graphon estimation for label {label}\")\n",
    "            \n",
    "            if ge == \"ISGL\":\n",
    "                logger.info(\"Using ISGL for graphon estimation\")\n",
    "                gnn_dim_hidden = [8]\n",
    "                epoch_show = int(n_epochs_inr / 5)\n",
    "                inr_dim_hidden = [20, 20]\n",
    "                batch_size_inr = 1024\n",
    "                inr_lr = 0.01\n",
    "                inr_w = 10\n",
    "                # Call the function from your module (ensure this is defined)\n",
    "                model_ISGL_0, _ = coords_prediction(inr_dim_hidden, gnn_dim_hidden, int(2*n_epochs_inr), epoch_show, inr_w, inr_graphs, inr_lr)\n",
    "                X_all_0, y_all_0, w_all_0 = graph2XY(inr_graphs, model_ISGL_0)\n",
    "                logger.info(\"Number of datapoints for graphon estimation: {}\".format(X_all_0.shape[0]))\n",
    "                trained_inr_0 = train_graphon(inr_dim_hidden, inr_w, X_all_0, y_all_0, w_all_0, n_epochs_inr, epoch_show, inr_lr, batch_size_inr)\n",
    "                graphon = get_graphon(100, trained_inr_0, coords=None)\n",
    "                graphons.append((label, graphon, trained_inr_0))\n",
    "                \n",
    "            elif ge == \"Moment\":\n",
    "                logger.info(\"Using Moment for graphon estimation\")\n",
    "                import networkx as nx\n",
    "                nx_graphs = [nx.from_numpy_array(graph) for graph in inr_graphs]\n",
    "                trained_model = train_Moment(nx_graphs, 0, \"MLP\")\n",
    "                graphon = get_graphon(100, trained_model, coords=None)\n",
    "                graphons.append((label, graphon, trained_model))\n",
    "                \n",
    "            elif ge == \"IGNR\":\n",
    "                logger.info(\"Using IGNR for graphon estimation\")\n",
    "                gl_mlp = IGNR_pg_wrapper([20,20,20], w0=30)\n",
    "                loss = gl_mlp.train(inr_graphs, K='input', n_epoch=n_epochs_inr, f_sample='fixed')\n",
    "                W1 = gl_mlp.get_W(100)\n",
    "                graphons.append((label, W1, gl_mlp))\n",
    "                \n",
    "            else:\n",
    "                # Fallback: align graphs and compute SVD-based graphon\n",
    "                align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(inr_graphs, padding=True, N=resolution)\n",
    "                logger.info(f\"Aligned graph shape: {align_graphs_list[0].shape}\")\n",
    "                graphon = universal_svd(align_graphs_list, threshold=0.2)\n",
    "                logger.info(f\"Graphon shape: {graphon.shape}\")\n",
    "                graphons.append((label, graphon, None))\n",
    "                \n",
    "        end_time = time()\n",
    "        num_classes = len(graphons)\n",
    "        logger.info(f\"Graphon estimation time per class: {(end_time - start_time)/num_classes} s\")\n",
    "\n",
    "        plt.figure(figsize=(int(1 + 3*num_classes), 3))\n",
    "        c = 1\n",
    "        for label, graphon, _ in graphons:\n",
    "            print(f\"graphon info: label:{label}; mean: {graphon.mean()}, shape, {graphon.shape}\")\n",
    "            plt.subplot(1, num_classes, c)\n",
    "            plt.imshow(graphon, cmap='hot', extent=[0, 1, 0, 1])\n",
    "            plt.xticks([])\n",
    "            plt.yticks([])\n",
    "            plt.title(r\"Class \" + str(c))\n",
    "            c += 1\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "                \n",
    "        num_sample = int(train_nums * aug_ratio / aug_num)\n",
    "        lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(aug_num,))\n",
    "        new_graph = []\n",
    "        for lam in lam_list:\n",
    "            logger.info(f\"lam: {lam}\")\n",
    "            logger.info(f\"num_sample: {num_sample}\")\n",
    "            two_graphons = random.sample(graphons, 2)\n",
    "            upper_bound = 600\n",
    "            lower_bound = 300\n",
    "            lower_bound = max(lower_bound, min_num_nodes)\n",
    "            new_graph += two_graphons_mixup(two_graphons, la=lam, num_sample=num_sample, ge=ge, resolution=[lower_bound, upper_bound])\n",
    "            logger.info(f\"New graph label: {new_graph[-1].y}\")\n",
    "            \n",
    "        avg_num_nodes, max_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density, min_num_nodes, std_num_nodes = stat_graph(new_graph)\n",
    "        logger.info(f\"New graphs: avg nodes: {avg_num_nodes}, min nodes: {min_num_nodes}, max nodes: {max_num_nodes}\")\n",
    "        logger.info(f\"Real augmentation ratio: {len(new_graph)/train_nums}\")\n",
    "        dataset = new_graph + dataset\n",
    "        train_nums += len(new_graph)\n",
    "        train_val_nums += len(new_graph)\n",
    "\n",
    "    # Prepare node features for whole dataset\n",
    "    dataset = prepare_dataset_x(dataset)\n",
    "    logger.info(f\"Dataset feature shape: {dataset[0].x.shape}\")\n",
    "    logger.info(f\"Dataset label shape: {dataset[0].y.shape}\")\n",
    "\n",
    "    num_features = dataset[0].x.shape[1]\n",
    "    num_classes = dataset[0].y.shape[0]\n",
    "\n",
    "    for data in dataset:\n",
    "        if not hasattr(data, 'edge_attr') or data.edge_attr is None:\n",
    "            data.edge_attr = torch.ones((data.edge_index.size(1), 3)) \n",
    "\n",
    "    # Split dataset into train, validation and test\n",
    "    train_dataset = dataset[:train_nums]\n",
    "    random.shuffle(train_dataset)\n",
    "    val_dataset = dataset[train_nums:train_val_nums]\n",
    "    test_dataset = dataset[train_val_nums:]\n",
    "    \n",
    "    logger.info(f\"Train dataset size: {len(train_dataset)}\")\n",
    "    logger.info(f\"Validation dataset size: {len(val_dataset)}\")\n",
    "    logger.info(f\"Test dataset size: {len(test_dataset)}\")\n",
    "\n",
    "    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "    val_loader = DataLoader(val_dataset, batch_size=batch_size)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=batch_size)\n",
    "\n",
    "    # Instantiate the model\n",
    "    if model_name == \"GIN\":\n",
    "        model_instance = GIN(num_features=num_features, num_classes=num_classes, num_hidden=num_hidden).to(device)\n",
    "    else:\n",
    "        logger.info(\"No valid model specified.\")\n",
    "        continue\n",
    "\n",
    "    optimizer = torch.optim.Adam(model_instance.parameters(), lr=learning_rate, weight_decay=5e-4)\n",
    "    scheduler = StepLR(optimizer, step_size=100, gamma=0.5)\n",
    "\n",
    "    # Train the model\n",
    "    last_10_epoch_acc = []  # To record test acc for the last 10 epochs\n",
    "    best_val_acc = 0\n",
    "    best_test_acc = 0\n",
    "\n",
    "    for epoch in range(1, num_epochs):\n",
    "        train_loss = train(model_instance, train_loader, optimizer, device, num_classes)\n",
    "        val_acc, val_loss = test(model_instance, val_loader, device, num_classes)\n",
    "        test_acc, test_loss = test(model_instance, test_loader, device, num_classes)\n",
    "        scheduler.step()\n",
    "\n",
    "        if val_acc >= best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "            best_test_acc = test_acc\n",
    "            best_epoch = epoch\n",
    "\n",
    "        if epoch > num_epochs - 10:\n",
    "            last_10_epoch_acc.append(test_acc)\n",
    "\n",
    "        if epoch % 30 == 0:\n",
    "            print('Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f}, Val Acc: {:.6f}, Test Acc: {:.6f}'.format(\n",
    "                epoch, train_loss, val_loss, test_loss, val_acc, test_acc))\n",
    "    \n",
    "    avg_last10 = np.mean(last_10_epoch_acc)\n",
    "    print(f\"Seed {seed}: Last 10 epochs average test acc: {np.round(avg_last10,3)}\")\n",
    "    print(f\"Seed {seed}: Best test acc (based on validation) = {best_test_acc} at epoch {best_epoch}\")\n",
    "\n",
    "    all_best_test_acc.append(best_test_acc)\n",
    "    all_last10_avg_acc.append(avg_last10)\n",
    "\n",
    "########################################\n",
    "# After All Seeds: Print Final Metrics\n",
    "########################################\n",
    "print(\"==========================================\")\n",
    "print(\"Overall Results Across Seeds:\")\n",
    "print(f\"Best Test Accuracy: Mean = {np.mean(all_best_test_acc):.3f}, Std = {np.std(all_best_test_acc):.3f}\")\n",
    "print(f\"Last 10 Epoch Test Accuracy Averag e: Mean = {np.mean(all_last10_avg_acc):.3f}, Std = {np.std(all_last10_avg_acc):.3f}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
