import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from gcn import GCN
from ogb.graphproppred import PygGraphPropPredDataset
from inductive_training import *


# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Worker count: 0 on CPU, 6 otherwise. NOTE(review): presumably meant for the
# DataLoaders, but it is not passed to any loader in this file.
num_workers = 6 if device.type != 'cpu' else 0
print(f"Using device: {device}")

# Root directory handed to the OGB dataset loader.
path = "data/inductive"

# Experiment grid: each (net, k, n_levels, method) combination is trained n_runs times.
n_runs = 3            # repetitions per configuration
levels = [2]          # n_levels values passed to train_network
methods = ["random"]  # method names passed to train_network
nets = ["GCN"]        # network labels iterated by the experiment loop
ks = [1]              # k values passed to train_network


# Load ogbg-molhiv and slice it with the split indices OGB provides.
dataset = PygGraphPropPredDataset(name="ogbg-molhiv", root=path)
split_idx = dataset.get_idx_split()
train_dataset, val_dataset, test_dataset = (
    dataset[split_idx[part]] for part in ("train", "valid", "test")
)
# NOTE(review): val_dataset is built but not used anywhere in this script.

# One graph per batch; only the training loader shuffles.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)

# Dataset tag forwarded to train_network.
dt = "MolHIV"

# A single GCN instance is shared by every configuration and run; its weights
# are re-initialised with reset_parameters() before each run.
# NOTE(review): out_channels uses train_dataset.num_classes -- confirm this
# matches the loss/metric used inside train_network.
model = GCN(in_channels=train_dataset.x.size(1),
            hidden_channels=192,
            out_channels=train_dataset.num_classes,
            num_layers=4,
            dropout=0.5).to(device)
learning_rate = 0.005
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# NOTE(review): `net` is not used below -- the same GCN is trained for every
# entry in `nets`.
for net in nets:
    for k in ks:
        for n_levels in levels:
            for method in methods:

                if n_levels == 1:
                    n_fine_epochs = 100
                else:
                    # Multilevel training (2 levels in the current grid).
                    n_fine_epochs = 40

                print("number of fine epochs:", n_fine_epochs, flush=True)
                print(f"Using method {method} with {n_levels} levels and k {k}:", flush=True)

                # Get results across multiple runs
                test_aucrocs = []
                for run in range(n_runs):
                    model.reset_parameters()
                    # Rebuild the optimizer as well: a reused Adam instance keeps its
                    # moment estimates and step counts from the previous run, which
                    # would make the runs dependent on each other.
                    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
                    run_max_test_aucroc = train_network(n_levels, n_fine_epochs, model, train_loader, test_loader,
                                                        optimizer, method, k, device, dt)
                    # float() accepts Python floats as well as 0-dim tensors / NumPy
                    # scalars (return type of train_network presumably one of these),
                    # so the statistics below work regardless.
                    test_aucrocs.append(float(run_max_test_aucroc))

                # Compute overall statistics
                max_test_aucroc = max(test_aucrocs)
                mean_test_aucroc = sum(test_aucrocs) / len(test_aucrocs)
                # The sample std (n - 1 denominator) is NaN for a single run; report 0 then.
                if len(test_aucrocs) > 1:
                    std_test_aucroc = torch.tensor(test_aucrocs, dtype=torch.float64).std().item()
                else:
                    std_test_aucroc = 0.0

                # Print results
                print("\n--- Final Results Across Runs ---", flush=True)
                print(f"Max Test AUCROC: {100 * max_test_aucroc:.2f}% (std: {100 * std_test_aucroc:.2f}%)", flush=True)
                print(f"Mean Test AUCROC: {100 * mean_test_aucroc:.2f}% (std: {100 * std_test_aucroc:.2f}%)", flush=True)