# forward_forward/data/dataloader.py

import torch
from torch.utils.data import DataLoader
from torchvision import datasets


def get_base_dataset(name: str, train: bool = True, transform=None, target_transform=None):
    """
    Build one of the supported torchvision datasets, selected by name.

    The name is matched case-insensitively. The dataset is created with
    ``root="~/datasets"`` and ``download=True``, and the returned object is
    tagged with two extra attributes: ``input_shape`` and ``num_classes``.

    Args:
        name (str): One of "mnist", "cifar10" or "cifar100" (any letter case).
        train (bool): Forwarded as the dataset constructor's ``train`` argument.
        transform (callable, optional): Forwarded as ``transform``.
        target_transform (callable, optional): Forwarded as ``target_transform``.
    Returns:
        Dataset: The constructed dataset with ``input_shape`` and
        ``num_classes`` set on it.
    Raises:
        ValueError: If the lower-cased name is not a supported dataset.
    """
    key = name.lower()

    # key -> (torchvision class name, input_shape, num_classes)
    catalog = {
        "mnist": ("MNIST", (1, 28, 28), 10),
        "cifar10": ("CIFAR10", (3, 32, 32), 10),
        "cifar100": ("CIFAR100", (3, 32, 32), 100),
    }

    entry = catalog.get(key)
    if entry is None:
        raise ValueError(f"Unsupported dataset: {key}")

    class_name, input_shape, num_classes = entry

    # Look the class up only after the name has been validated.
    dataset_cls = getattr(datasets, class_name)
    built = dataset_cls(
        root="~/datasets",
        train=train,
        download=True,
        transform=transform,
        target_transform=target_transform,
    )
    built.input_shape = input_shape
    built.num_classes = num_classes
    return built


def get_train_val_datasets(dataset_name, transform, val_fraction, seed):
    """
    Split the training set ONCE into train/val subsets to avoid data leakage.

    The full training set is loaded through ``get_base_dataset`` and divided
    with ``torch.utils.data.random_split`` using a generator seeded with
    ``seed``, so the same arguments always produce the same split. Both
    returned subsets are drawn from the same underlying dataset object and
    therefore share ``transform``.

    Args:
        dataset_name (str): Dataset name understood by ``get_base_dataset``.
        transform (callable, optional): Input transform for the base dataset.
        val_fraction (float): Fraction of the training set reserved for
            validation; must lie in [0, 1]. The validation size is
            ``int(len(dataset) * val_fraction)`` (rounded toward zero).
        seed (int): Seed for the generator that drives the split.
    Returns:
        tuple: ``(train_dataset, val_dataset)``.
    Raises:
        ValueError: If ``val_fraction`` is outside [0, 1] (NaN included).
    """
    # Validate before loading anything. An out-of-range fraction would make
    # one of the sizes negative while the two sizes still add up to
    # len(full_dataset), so the split is not guaranteed to be rejected
    # downstream and could be silently wrong.
    if not (0.0 <= val_fraction <= 1.0):
        raise ValueError(f"val_fraction must be in [0, 1], got {val_fraction!r}")

    full_dataset = get_base_dataset(dataset_name, train=True, transform=transform)
    val_size = int(len(full_dataset) * val_fraction)
    train_size = len(full_dataset) - val_size

    # A dedicated, seeded generator keeps the split reproducible without
    # touching the global RNG state.
    generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size], generator=generator
    )
    return train_dataset, val_dataset


def get_dataloader(
    dataset_or_name,
    batch_size: int,
    transform=None,
    target_transform=None,
    split: str = "train",  # "train", "val", or "test"
    val_fraction: float = 0.1,
    seed: int = 42,
    shuffle: bool = True,
    drop_last: bool = True
) -> DataLoader:
    """
    Return a DataLoader for the specified split of the dataset.

    Two call styles are supported:

    * ``dataset_or_name`` is a ``torch.utils.data.Dataset``: it is wrapped in
      a DataLoader as-is. ``transform`` and ``target_transform`` are not
      applied in this case.
    * ``dataset_or_name`` is a dataset name: only ``split="test"`` is
      supported; the test set is built with ``get_base_dataset``.

    Args:
        dataset_or_name (Dataset | str): A ready-made dataset, or a dataset
            name understood by ``get_base_dataset``.
        batch_size (int): Batch size.
        transform (callable, optional): Input transform (name path only).
        target_transform (callable, optional): Label transform (name path only).
        split (str): One of "train", "val", or "test" (case-insensitive).
        val_fraction (float): Currently unused by this function; kept for
            backward compatibility.
        seed (int): Currently unused by this function; kept for backward
            compatibility.
        shuffle (bool): Whether to shuffle. Only takes effect for the "train"
            split; every other split is never shuffled.
        drop_last (bool): Whether to drop the last incomplete batch. Applied
            to every split, including "val" and "test".

    Returns:
        DataLoader: Configured dataloader for the specified split.

    Raises:
        RuntimeError: If a name string is given with a split other than "test".
    """
    # Normalize once so both paths below compare the split case-insensitively.
    split = split.lower()

    # If dataset_or_name is a Dataset, just wrap it in a DataLoader.
    if isinstance(dataset_or_name, torch.utils.data.Dataset):
        return DataLoader(
            dataset_or_name,
            batch_size=batch_size,
            # Honor the caller's flag, but never shuffle non-train splits.
            shuffle=bool(shuffle) and split == "train",
            drop_last=drop_last,
        )

    # A name string is only accepted for the held-out test set.
    if split == "test":
        dataset = get_base_dataset(
            dataset_or_name,
            train=False,
            transform=transform,
            target_transform=target_transform,
        )
        return DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last)

    raise RuntimeError("For train/val, pass a Dataset object, not a name string.")
