from gnnboundary import *
from tqdm import tqdm




def collab(*, epochs=100, lr=0.001, seed=12345, hidden_channels=64,
           num_layers=5, k=10):
    """Train a ``GCNClassifier`` on ``CollabDataset`` and print per-epoch metrics.

    All parameters are keyword-only and default to the values this script
    originally hard-coded, so calling ``collab()`` keeps its previous
    configuration.

    Args:
        epochs: Number of iterations of the fit/evaluate loop.
        lr: Learning rate forwarded to ``train.model_fit``.
        seed: Seed forwarded to the ``CollabDataset`` constructor.
        hidden_channels: Forwarded to ``GCNClassifier(hidden_channels=...)``.
        num_layers: Forwarded to ``GCNClassifier(num_layers=...)``.
        k: Forwarded to ``dataset.train_test_split(k=...)``.

    Returns:
        None. Progress is reported on stdout, one line per epoch.
    """
    dataset = CollabDataset(seed=seed)
    model = GCNClassifier(node_features=len(dataset.NODE_CLS),
                          num_classes=len(dataset.GRAPH_CLS),
                          hidden_channels=hidden_channels,
                          num_layers=num_layers)

    # Split exactly once, before training starts. Re-splitting on every epoch
    # is wasted work when the split is deterministic, and would let held-out
    # graphs drift into the training set if the split is randomized -- either
    # way the validation subset must stay fixed for the metrics to be
    # comparable across epochs.
    train, val = dataset.train_test_split(k=k)

    for epoch in tqdm(range(epochs)):
        train_loss = train.model_fit(model, lr=lr)
        train_metrics = train.model_evaluate(model)
        val_metrics = val.model_evaluate(model)
        # NOTE(review): the "Test" labels below report metrics on the
        # held-out `val` subset; F1 is printed without a format spec,
        # presumably because it is not a plain scalar -- confirm.
        print(f"Epoch: {epoch:03d}, "
              f"Train Loss: {train_loss:.4f}, "
              f"Train Acc: {train_metrics['acc']:.4f}, "
              f"Test Acc: {val_metrics['acc']:.4f}, "
              f"Train F1: {train_metrics['f1']}, "
              f"Test F1: {val_metrics['f1']}")



# Script entry point: run the training loop only when this file is executed
# directly (not when imported), then print a completion marker.
if __name__ == "__main__":
    collab()
    print("Done!")