import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from gcn import GCN
from torch_geometric.datasets import TUDataset
from inductive_training import *


# Run on the current CUDA device when one is available, otherwise on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _use_cuda else torch.device('cpu')
# Loader worker count: none on the CPU, six when a GPU is in use.
# NOTE(review): this value is not passed to the DataLoaders in this script —
# confirm whether it is meant to be.
num_workers = 6 if device.type != 'cpu' else 0
print(f"Using device: {device}")

# Experiment grid: the sweep below visits every (net, k, n_levels, method)
# combination and repeats each one n_runs times.
nets = ["GCN"]
ks = [1]
levels = [2]
methods = ["random"]
n_runs = 3

# Root directory handed to TUDataset.
path = "data/inductive"

# Dataset name; used both to load the data and as the `dt` argument that the
# sweep below passes to train_network.
dt = "NCI1"
dataset = TUDataset(root=path, name=dt)

# Permute once, with a fixed seed, before the head/tail split. TU datasets are
# not guaranteed to be stored in random order (NOTE(review): several are grouped
# by class label — confirm for NCI1). Slicing an ordered dataset would give the
# train and test sets different class distributions. The fixed seed keeps the
# split identical across invocations.
_split_perm = torch.randperm(len(dataset), generator=torch.Generator().manual_seed(0))
dataset = dataset[_split_perm]
train_dataset, test_dataset = dataset[:3000], dataset[3000:]
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Take the input width and class count from dataset metadata rather than from
# `train_dataset.x`. That attribute access is version-dependent and, on a sliced
# dataset, has to collate every graph in the split just to read one dimension.
model = GCN(in_channels=dataset.num_node_features,
            hidden_channels=192,
            out_channels=dataset.num_classes,
            num_layers=4,
            dropout=0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Sweep the experiment grid. Each configuration is trained from scratch n_runs
# times, and the per-run values returned by train_network are summarized.
# NOTE(review): `net` is not used inside the loop — the model built above is the
# same GCN regardless of the entry in `nets`.
for net in nets:
    for k in ks:
        for n_levels in levels:
            for method in methods:

                # Fine-level epoch budget: 100 for single-level training,
                # 40 for every other level count (the grid above uses 2 levels).
                n_fine_epochs = 100 if n_levels == 1 else 40

                print(f"number of fine epochs: {n_fine_epochs}", flush=True)
                print(f"Using method {method} with {n_levels} levels and k {k}:", flush=True)

                # Get results across multiple runs
                test_accs = []
                for run in range(n_runs):
                    model.reset_parameters()
                    # Re-create the optimizer as well. Resetting the weights alone
                    # would carry the optimizer's per-parameter state (for Adam:
                    # running moment estimates and step counts) from the previous
                    # run into this one, so runs would not be independent.
                    # Rebuilding from `optimizer.defaults` keeps the original
                    # hyperparameters.
                    optimizer = type(optimizer)(model.parameters(), **optimizer.defaults)
                    run_max_test_acc = train_network(n_levels, n_fine_epochs, model, train_loader, test_loader,
                                                     optimizer, method, k, device, dt)
                    test_accs.append(run_max_test_acc)

                # Compute overall statistics
                accs = torch.tensor(test_accs, dtype=torch.float)
                max_test_acc = accs.max().item()
                mean_test_acc = accs.mean().item()
                # The unbiased std of a single sample is undefined (it would print
                # as nan), so report zero spread when there is only one run.
                std_test_acc = accs.std().item() if accs.numel() > 1 else 0.0

                # Print results
                print("\n--- Final Results Across Runs ---", flush=True)
                print(f"Max Test Accuracy: {100 * max_test_acc:.2f}% (std: {100 * std_test_acc:.2f}%)", flush=True)
                print(f"Mean Test Accuracy: {100 * mean_test_acc:.2f}% over {accs.numel()} runs", flush=True)