import os
from torch_geometric.loader import DataLoader
from paper_dataset import PaperDataset


# Root directory holding the pre-split .npy files. This is a relative path,
# so it resolves against the current working directory, not this file's location.
DATA_BASE_PATH = "./data/paper/"

# All three split files follow the scheme "<split>_casual_1_3.npy".
# NOTE(review): "casual" may be intended as "causal"; the names must match
# what is on disk, so they are left untouched here.
TRAIN_NPY_PATH, VAL_NPY_PATH, TEST_NPY_PATH = (
    os.path.join(DATA_BASE_PATH, f"{split_name}_casual_1_3.npy")
    for split_name in ("train", "val", "test")
)

def load_datasets(train_path=None, val_path=None, test_path=None):
    """Load the train, validation and test ``PaperDataset`` splits.

    Args:
        train_path: Path to the training ``.npy`` file. ``None`` (default)
            uses ``TRAIN_NPY_PATH``.
        val_path: Path to the validation ``.npy`` file. ``None`` (default)
            uses ``VAL_NPY_PATH``.
        test_path: Path to the test ``.npy`` file. ``None`` (default)
            uses ``TEST_NPY_PATH``.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)`` on success. If any
        split fails to load, returns ``(None, None, None)``. The failure is
        printed rather than raised, so callers must check for ``None``.
    """
    split_paths = {
        "train": TRAIN_NPY_PATH if train_path is None else train_path,
        "val": VAL_NPY_PATH if val_path is None else val_path,
        "test": TEST_NPY_PATH if test_path is None else test_path,
    }

    loaded = {}
    for split, path in split_paths.items():
        try:
            loaded[split] = PaperDataset(path)
        except Exception as e:
            # Deliberately best-effort: callers (create_dataloaders, __main__)
            # detect failure via the None triple rather than an exception.
            # The message names the failing split/file to aid debugging.
            print(f"Error loading datasets: {split} split from {path!r}: {e}")
            return None, None, None

    train_dataset = loaded["train"]
    val_dataset = loaded["val"]
    test_dataset = loaded["test"]
    print(f"Datasets loaded: Train={len(train_dataset)} graphs, Val={len(val_dataset)} graphs, Test={len(test_dataset)} graphs")
    return train_dataset, val_dataset, test_dataset

def create_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=1):
    """Wrap the three dataset splits in PyTorch Geometric ``DataLoader``s.

    Args:
        train_dataset: Training split, or ``None`` if loading failed.
        val_dataset: Validation split, or ``None`` if loading failed.
        test_dataset: Test split, or ``None`` if loading failed.
        batch_size: Number of graphs per batch for all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles. Returns ``(None, None, None)`` if any split is ``None``,
        i.e. the failure sentinel produced by ``load_datasets``.
    """
    splits = (train_dataset, val_dataset, test_dataset)
    # Compare against None explicitly. A dataset that loaded fine but is
    # empty has len() == 0 and is therefore falsy; truthiness would
    # misreport it as a loading failure.
    if any(split is None for split in splits):
        print("Cannot create DataLoaders because dataset loading failed.")
        return None, None, None

    # NOTE(review): an empty train_dataset with shuffle=True may be rejected
    # by the underlying random sampler. Confirm this is the desired failure mode.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False)
    test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False)
    print("DataLoaders created.")
    return train_loader, val_loader, test_loader

if __name__ == "__main__":
    # Script entry point: load every split, then build loaders.
    # create_dataloaders itself handles the (None, None, None) failure triple.
    loaded_splits = load_datasets()
    train_loader, val_loader, test_loader = create_dataloaders(*loaded_splits)
