import torch
import pickle
import numpy as np
# import sys
# sys.path.append("..")
from params import args
from data_handler import DataHandler

# Both device entries refer to the same GPU, the one selected via args.gpu.
args.devices = [f'cuda:{args.gpu}'] * 2
print(args)

# For each dataset, print the size and per-class breakdown of the labels
# attached to the handler's tst_loader dataset.
for dataset_name in ('cora', 'arxiv', 'pubmed', 'home', 'tech'):
    handler = DataHandler(dataset_name)
    tst_labels = handler.tst_loader.dataset.labels

    print(f"total number of samples: {tst_labels.shape[0]}")

    # np.unique yields the sorted distinct label values and, with
    # return_counts=True, how many times each one occurs.
    label_values, label_counts = np.unique(tst_labels, return_counts=True)
    print(f"class number: {label_values.shape[0]}")

    for label_value, label_count in zip(label_values, label_counts):
        print(f" class {label_value}: {label_count} samples")