

import torch
from torch_geometric.loader import DataLoader
from molecular_dataset import MolecularDataset 
import os





# Directory that holds the train/validation/test .npy dataset files.
DATA_BASE_PATH = "./data/molecular/"

# os.path.join only inserts a separator when one is missing, so these paths
# stay valid even if DATA_BASE_PATH is later edited to drop its trailing slash.
TRAIN_NPY_PATH = os.path.join(DATA_BASE_PATH, "train_crcg.npy")
VAL_NPY_PATH = os.path.join(DATA_BASE_PATH, "val_crcg.npy")
TEST_NPY_PATH = os.path.join(DATA_BASE_PATH, "test_crcg.npy")

# Batch size handed to the DataLoaders built by the script entry point.
BATCH_SIZE = 32


def load_datasets():
    """Load the training, validation and test splits as ``MolecularDataset`` objects.

    Progress messages and a short summary are printed to stdout. Failures
    are reported on stdout instead of being raised, so callers must check
    the result for ``None``.

    Returns:
        A ``(train_dataset, val_dataset, test_dataset)`` tuple on success,
        or ``(None, None, None)`` if anything went wrong while loading.
    """
    try:
        print("Abs path:", os.path.abspath(TRAIN_NPY_PATH))

        # Splits are loaded in this fixed order so the progress output and
        # the unpacking below stay aligned.
        splits = (
            ("training", TRAIN_NPY_PATH),
            ("validation", VAL_NPY_PATH),
            ("test", TEST_NPY_PATH),
        )
        loaded = []
        for split_name, npy_path in splits:
            print(f"Loading {split_name} data from: {npy_path}")
            loaded.append(MolecularDataset(npy_path=npy_path))
        train_dataset, val_dataset, test_dataset = loaded

        print("\nDatasets loaded successfully:")
        print(f"  Training set: {len(train_dataset)} graphs, {train_dataset.num_features} features")
        print(f"  Validation set: {len(val_dataset)} graphs")
        print(f"  Test set: {len(test_dataset)} graphs")

        return train_dataset, val_dataset, test_dataset

    except FileNotFoundError as e:
        # A missing file is the most likely failure, so give the user the
        # exact paths that were tried.
        print(f"Error: Dataset file not found. {e}")
        print("Please ensure the .npy files exist at the specified paths:")
        print(f"  {TRAIN_NPY_PATH}")
        print(f"  {VAL_NPY_PATH}")
        print(f"  {TEST_NPY_PATH}")
        print("And that DATA_BASE_PATH is set correctly in this script.")
        return None, None, None
    except Exception as e:
        # Catch-all boundary for this script: KeyError, ValueError and
        # TypeError are all subclasses of Exception, so listing them
        # alongside it was redundant.
        print(f"Error loading or processing dataset: {e}")
        return None, None, None


def create_dataloaders(train_dataset, val_dataset, test_dataset, batch_size):
    """Wrap the three dataset splits in ``DataLoader`` objects.

    Only the training loader shuffles; validation and test loaders keep the
    dataset order.

    Args:
        train_dataset: Training split.
        val_dataset: Validation split.
        test_dataset: Test split.
        batch_size: Batch size passed to every loader.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple, or
        ``(None, None, None)`` if any of the datasets is falsy.
    """
    # NOTE(review): this is a truthiness test, so a dataset whose __len__ is 0
    # is rejected the same way as None -- confirm that is intended.
    have_all_splits = train_dataset and val_dataset and test_dataset
    if not have_all_splits:
        print("Cannot create DataLoaders because dataset loading failed.")
        return None, None, None

    shuffle_plan = (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
    loaders = tuple(
        DataLoader(dataset, batch_size=batch_size, shuffle=do_shuffle)
        for dataset, do_shuffle in shuffle_plan
    )

    print("\nDataloaders created.")
    return loaders


if __name__ == "__main__":
    # Example driver: load the splits, report basic dataset properties and
    # build the DataLoaders.
    train_ds, val_ds, test_ds = load_datasets()

    if train_ds:
        num_features = train_ds.num_features

        # The class count is derived from the largest label value seen in any
        # split that carries labels (labels are treated as 0-based indices).
        max_label = 0
        for split in (train_ds, val_ds, test_ds):
            if split and split.data.y is not None:
                max_label = max(max_label, split.data.y.max().item())
        num_classes = int(max_label) + 1

        print("\nDataset properties:")
        print(f"  Number of node features: {num_features}")
        print(f"  Number of classes: {num_classes}")

        train_loader, val_loader, test_loader = create_dataloaders(
            train_ds, val_ds, test_ds, BATCH_SIZE
        )

        print("\nIntegration example complete. You can now integrate these loaders into your CaNet training script.")

