import torch
from torch_geometric.data import Data
from ogb.nodeproppred import NodePropPredDataset, Evaluator
from transductive_training import train_network, init_model

# Training script for the OGBN-MAG dataset (paper nodes and paper-cites-paper edges only).


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Worker count is tied to the chosen device: 4 off-CPU, 0 on the CPU.
num_workers = 4 if device.type != 'cpu' else 0
print(f"Using device: {device}")

# Load dataset: index 0 yields the graph dictionary and its label dictionary.
dataset = NodePropPredDataset(name='ogbn-mag')
graph, labels = dataset[0]

# Only the paper nodes are used: their features, their labels and the
# paper -> paper citation edges. Everything is moved to `device`.
paper_features = graph['node_feat_dict']['paper']
paper_labels = labels['paper']
citation_edges = graph['edge_index_dict'][('paper', 'cites', 'paper')]

x = torch.tensor(paper_features, dtype=torch.float).to(device)
y = torch.tensor(paper_labels, dtype=torch.long).to(device)
edge_index = torch.tensor(citation_edges, dtype=torch.long).to(device)
data = Data(x=x, edge_index=edge_index, y=y.squeeze())

# Turn the official index split into boolean node masks stored on `data`.
# The masks are allocated without a device argument, i.e. on torch's
# default device.
split_idx = dataset.get_idx_split()
for mask_attr, split_name in (('train_mask', 'train'),
                              ('val_mask', 'valid'),
                              ('test_mask', 'test')):
    node_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    node_mask[split_idx[split_name]['paper']] = True
    setattr(data, mask_attr, node_mask)

# Sparse adjacency with unit weight per listed edge, then symmetrised by
# adding its transpose (entries present in both directions accumulate).
values = torch.ones(data.edge_index.shape[1], device=device)
adj_shape = (data.num_nodes, data.num_nodes)
directed_adj = torch.sparse_coo_tensor(data.edge_index, values, adj_shape).to(device)
adj = directed_adj + directed_adj.t()
data.adj = adj

evaluator = Evaluator(name='ogbn-mag')

# Set methods, models and levels
# Every combination of network, connectivity, number of levels and method
# listed below is trained `n_runs` times.
n_runs = 3  # number of experiments of each setting
levels = [3]
methods = ["random"]
nets = ["GCN"]
p_vals = [1]  # use larger integers to enhance graph connectivity

# Fine-level epoch budget for each number of levels; any value not listed
# here (e.g. 4 levels) uses the fallback budget.
fine_epochs_by_levels = {1: 2000, 2: 1000, 3: 800}
fallback_fine_epochs = 600

for net in nets:
    for k in p_vals:
        for n_levels in levels:
            # The epoch budget depends only on the number of levels, so it is
            # looked up once per level instead of once per method.
            n_fine_epochs = fine_epochs_by_levels.get(n_levels, fallback_fine_epochs)

            for method in methods:
                print("number of fine epochs:", n_fine_epochs, flush=True)
                print(f"Model: {net}", flush=True)
                print(f"Using method {method} with {n_levels} levels and connectivity {k}:", flush=True)

                # Get results across multiple runs
                train_accs, val_accs, test_accs = [], [], []
                for run in range(n_runs):
                    # Build a fresh model *and* optimizer for every run.
                    # Resetting only the model parameters would leave any
                    # state held by a shared optimizer (e.g. momentum or
                    # moment estimates) from the previous run in place, so
                    # the runs would not be independent.
                    model, optimizer = init_model(net, dataset, data, device)
                    model.reset_parameters()

                    # NOTE(review): the three return values are presumably
                    # the best train/val/test accuracies reached during the
                    # run -- confirm against train_network.
                    run_max_train, run_max_val, run_max_test = train_network(
                        n_levels, n_fine_epochs, model, data, optimizer,
                        method, k, device, "ogbn-mag")
                    train_accs.append(run_max_train)
                    val_accs.append(run_max_val)
                    test_accs.append(run_max_test)

                # Compute overall statistics: best value over the runs and the
                # spread (torch's default standard deviation) between runs.
                max_train = max(train_accs)
                max_val = max(val_accs)
                max_test = max(test_accs)
                std_train = torch.tensor(train_accs).std().item()
                std_val = torch.tensor(val_accs).std().item()
                std_test = torch.tensor(test_accs).std().item()

                # Print results
                print("\n--- Final Results Across Runs ---", flush=True)
                print(f"Max Train Accuracy: {100 * max_train:.2f}% (std: {100 * std_train:.2f}%)", flush=True)
                print(f"Max Validation Accuracy: {100 * max_val:.2f}% (std: {100 * std_val:.2f}%)", flush=True)
                print(f"Max Test Accuracy: {100 * max_test:.2f}% (std: {100 * std_test:.2f}%)", flush=True)
