import torch
from data_utils import get_mutag_loaders
from model import NodeHaarUnpoolClassifier
from train_eval import train_model

def main(
    *,
    hid_dim: int = 64,
    max_K=None,
    num_levels: int = 4,
    batch_size: int = 60,
    seed: int = 42,
    epochs: int = 150,
) -> None:
    """Build a NodeHaarUnpoolClassifier for MUTAG and train it.

    All hyperparameters are keyword-only with defaults, so ``main()`` can
    still be called with no arguments.

    Args:
        hid_dim: Hidden dimension passed to the model.
        max_K: Passed to the model as ``max_K``. If None, it defaults to the
            largest per-graph node count in the dataset (see the review note
            below).
        num_levels: Number of levels. The same value goes to the model
            (``num_levels``) and to ``train_model`` (``levels``) so the two
            cannot drift apart.
        batch_size: Batch size forwarded to ``get_mutag_loaders``.
        seed: Seed forwarded to ``get_mutag_loaders``.
        epochs: Number of training epochs forwarded to ``train_model``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, test_loader, dataset = get_mutag_loaders(
        batch_size=batch_size, seed=seed
    )

    # NOTE(review): assumes `dataset` is a torch_geometric Dataset exposing
    # `num_features` / `num_classes` and yielding Data objects that have
    # `num_nodes` — confirm against data_utils.get_mutag_loaders.
    in_dim = dataset.num_features
    num_classes = dataset.num_classes
    if max_K is None:
        # NOTE(review): assumes max_K is an upper bound on nodes per graph;
        # confirm against NodeHaarUnpoolClassifier's definition.
        max_K = max(int(graph.num_nodes) for graph in dataset)

    model = NodeHaarUnpoolClassifier(
        in_dim=in_dim,
        hid_dim=hid_dim,
        num_classes=num_classes,
        max_K=max_K,
        num_levels=num_levels,
    ).to(device)

    if torch.cuda.device_count() > 1:
        # NOTE(review): torch.nn.DataParallel only scatters tensors (and
        # tuples/lists/dicts of them). If batches are PyG Batch objects, each
        # replica may receive the whole batch on the original device; PyG
        # ships torch_geometric.nn.DataParallel for that case — verify.
        model = torch.nn.DataParallel(model)

    train_model(
        model,
        train_loader,
        val_loader,
        test_loader,
        device=device,
        epochs=epochs,
        levels=num_levels,  # keep in sync with the model's num_levels
    )

# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

