# dataset.py
# -*- coding: utf-8 -*-
"""
QM9 DataLoader helper (PyTorch Geometric).
- First-time use will auto-download & process.
- Standard split: train=110000, val=10000, test=rest.
- Clean DataLoader construction via functools.partial.

Usage:
    from dataset import get_qm9_dataloaders
    train_loader, val_loader, test_loader = get_qm9_dataloaders(
        root="./data/QM9", batch_size=64, seed=42, num_workers=4
    )
"""

from __future__ import annotations

import os
from functools import partial
from typing import Callable, Optional, Tuple

import torch
from torch.utils.data import random_split
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader


def _qm9_split_sizes(n_total: int, n_train: int, n_val: int) -> Tuple[int, int, int]:
    """Compute ``(train, val, test)`` split sizes that sum exactly to ``n_total``.

    The requested ``n_train``/``n_val`` (assumed non-negative) are used as-is
    and the test split gets the remainder. If the request leaves no room for a
    non-empty test split (the dataset is smaller than expected), fall back to
    an 80/10/10 split; flooring leaves any rounding remainder in test.
    """
    if n_train + n_val < n_total:
        return n_train, n_val, n_total - n_train - n_val
    # Requested sizes do not fit: fall back to 80/10/10.
    train = int(0.8 * n_total)
    val = int(0.1 * n_total)
    return train, val, n_total - train - val


def get_qm9_dataloaders(
    root: str = "./data/QM9",
    batch_size: int = 256,
    seed: int = 42,
    num_workers: int = 4,
    pin_memory: bool = True,
    *,
    n_train: int = 110_000,
    n_val: int = 10_000,
    transform: Optional[Callable] = None,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Load (or auto-download) QM9 and return train/val/test DataLoaders.

    With the default sizes this is the standard QM9 split
    (train=110000, val=10000, test=rest). If the dataset is too small to
    leave a non-empty test split after the requested train/val sizes, an
    80/10/10 split is used instead.

    Args:
        root: Dataset root dir. If not existing, PyG will download & process on first run.
        batch_size: Per-iteration number of graphs.
        seed: RNG seed for reproducible random_split.
        num_workers: DataLoader worker count.
        pin_memory: Whether to pin memory for faster host->GPU transfer.
            Treated as False when CUDA is not available.
        n_train: Requested number of training molecules (keyword-only).
        n_val: Requested number of validation molecules (keyword-only).
        transform: Optional ``transform`` forwarded to the ``QM9``
            constructor (keyword-only). Default None matches prior behavior.

    Returns:
        (train_loader, val_loader, test_loader)

    Raises:
        ValueError: if ``n_train`` or ``n_val`` is negative.
    """
    # Validate before touching the filesystem / triggering a download.
    if n_train < 0 or n_val < 0:
        raise ValueError(
            f"split sizes must be non-negative, got n_train={n_train}, n_val={n_val}"
        )

    # 1) Load dataset (PyG will auto-download/process if needed)
    os.makedirs(root, exist_ok=True)
    dataset = QM9(root, transform=transform)
    n_total = len(dataset)  # typically 133885

    # 2) Split; the seeded generator makes the split reproducible across runs.
    sizes = _qm9_split_sizes(n_total, n_train, n_val)
    generator = torch.Generator().manual_seed(seed)
    train_set, val_set, test_set = random_split(dataset, list(sizes), generator=generator)

    # 3) Shared DataLoader config for all three splits.
    mkloader = partial(
        DataLoader,
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned memory only helps host->CUDA copies; without CUDA torch
        # merely emits a warning, so don't request it there.
        pin_memory=pin_memory and torch.cuda.is_available(),
        # persistent_workers is only valid when num_workers > 0.
        persistent_workers=num_workers > 0,
    )

    train_loader = mkloader(train_set, shuffle=True)
    val_loader = mkloader(val_set, shuffle=False)
    test_loader = mkloader(test_set, shuffle=False)

    # Optional log
    print(f"[QM9] total={n_total} | train={len(train_set)} val={len(val_set)} test={len(test_set)}")
    return train_loader, val_loader, test_loader

def get_qm9_full_loader(
    root: str = "./data/QM9",
    batch_size: int = 256,
    shuffle: bool = False,
    num_workers: int = 4,
    pin_memory: bool = True,
    *,
    transform: Optional[Callable] = None,
) -> DataLoader:
    """
    Load (or auto-download) QM9 and return a single DataLoader over ALL molecules.

    Args:
        root: Dataset root dir. Created if missing; PyG downloads & processes on first run.
        batch_size: Per-iteration number of graphs.
        shuffle: Whether to shuffle the whole dataset each epoch.
        num_workers: DataLoader worker count.
        pin_memory: Whether to pin memory for faster host->GPU transfer.
            Treated as False when CUDA is not available.
        transform: Optional ``transform`` forwarded to the ``QM9``
            constructor (keyword-only). Default None matches prior behavior.

    Returns:
        loader over the full QM9 dataset.
    """
    os.makedirs(root, exist_ok=True)
    dataset = QM9(root, transform=transform)
    n_total = len(dataset)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned memory only helps host->CUDA copies; without CUDA torch
        # merely emits a warning, so don't request it there.
        pin_memory=pin_memory and torch.cuda.is_available(),
        # persistent_workers is only valid when num_workers > 0.
        persistent_workers=num_workers > 0,
    )

    print(f"[QM9] total={n_total} | using full dataset in a single loader")
    return loader


# Public API of this module (what `from dataset import *` exports).
__all__ = [
    "get_qm9_dataloaders",
    "get_qm9_full_loader",
]
