# data_utils.py
import torch
import torchvision
import torchvision.transforms as transforms
import logging

logger = logging.getLogger(__name__)


def _expand_gray_to_rgb(x):
    """Return a 3-channel view of a single-channel image tensor.

    This is a named module-level function rather than a lambda so that a
    transform pipeline containing it can be pickled. DataLoader worker
    processes started with the ``spawn`` method need to pickle the dataset
    (and therefore its transforms), and lambdas cannot be pickled.
    """
    return x.expand(3, -1, -1)


def get_dataset_transforms(dataset_name):
    """Get the transform pipeline and dataset class for a named dataset.

    Args:
        dataset_name: One of ``'cifar10'``, ``'cifar100'`` or ``'mnist'``
            (matched exactly, case-sensitive).

    Returns:
        A tuple ``(transform, input_channels, dataset_class)`` where
        ``transform`` is a ``transforms.Compose`` that converts to a tensor
        and normalizes with the hard-coded per-dataset statistics,
        ``input_channels`` is the channel count produced by the transform,
        and ``dataset_class`` is the matching ``torchvision.datasets`` class.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'cifar10':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ])
        input_channels = 3
        dataset_class = torchvision.datasets.CIFAR10
    elif dataset_name == 'cifar100':
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ])
        input_channels = 3
        dataset_class = torchvision.datasets.CIFAR100
    elif dataset_name == 'mnist':
        transform = transforms.Compose([
            transforms.ToTensor(),
            # Normalize while still single-channel, then expand to 3 channels.
            transforms.Normalize((0.1307,), (0.3081,)),
            # Named function (not a lambda) so the pipeline stays picklable.
            transforms.Lambda(_expand_gray_to_rgb),  # Convert grayscale to RGB
        ])
        input_channels = 3  # Always 3 since we're expanding grayscale to RGB
        dataset_class = torchvision.datasets.MNIST
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    return transform, input_channels, dataset_class


def load_datasets(dataset_class, transform, data_root='./data'):
    """Build the training and test splits of a dataset.

    Args:
        dataset_class: Callable accepting ``root``, ``train``, ``download``
            and ``transform`` keyword arguments; its ``__name__`` is logged.
        transform: Transform passed unchanged to both splits.
        data_root: Directory passed as ``root`` to ``dataset_class``.

    Returns:
        A tuple ``(train_dataset, test_dataset)``.

    Raises:
        Exception: Whatever ``dataset_class`` raises is logged and re-raised
            unchanged.
    """
    logger.info(f"Loading {dataset_class.__name__} from {data_root}...")

    def _build_split(is_train):
        # Both splits share every argument except the ``train`` flag.
        return dataset_class(
            root=data_root,
            train=is_train,
            download=True,
            transform=transform,
        )

    try:
        # Training split is constructed first, then the test split.
        train_dataset = _build_split(True)
        test_dataset = _build_split(False)
    except Exception as e:
        logger.error(f"Failed to download/load dataset. Check network or permissions for {data_root}. Error: {e}")
        raise

    return train_dataset, test_dataset


def get_num_classes(dataset_obj):
    """Determine the number of classes from the dataset object.

    Resolution order:
      1. ``len(dataset_obj.classes)`` when a non-None ``classes`` attribute exists.
      2. The number of distinct values in ``dataset_obj.targets`` when a
         non-None ``targets`` attribute exists.
      3. ``10`` when the object is a ``torchvision.datasets.MNIST`` instance.

    Args:
        dataset_obj: Dataset-like object to inspect.

    Returns:
        The number of classes as an ``int``.

    Raises:
        ValueError: If none of the strategies above applies.
    """
    classes = getattr(dataset_obj, 'classes', None)
    if classes is not None:
        return len(classes)

    targets = getattr(dataset_obj, 'targets', None)
    if targets is not None:
        # as_tensor reuses an existing tensor instead of copy-constructing it;
        # torch.tensor(<tensor>) would emit a UserWarning and make a copy.
        return len(torch.unique(torch.as_tensor(targets)))

    # Fallback for MNIST if .classes is not standard
    if isinstance(dataset_obj, torchvision.datasets.MNIST):
        return 10
    raise ValueError("Could not determine number of classes from dataset.")


def create_data_loaders(train_dataset, test_dataset, batch_size, num_workers, device_type='cpu'):
    """Wrap the train and test datasets in ``DataLoader`` objects.

    The training loader shuffles and drops the final incomplete batch; the
    test loader iterates in order and keeps every sample. Pinned memory is
    requested only when ``device_type`` equals ``'cuda'``.

    Args:
        train_dataset: Dataset used for training.
        test_dataset: Dataset used for evaluation.
        batch_size: Batch size shared by both loaders.
        num_workers: Worker count shared by both loaders.
        device_type: Device type string; ``'cuda'`` turns on ``pin_memory``.

    Returns:
        A tuple ``(train_loader, test_loader)``.
    """
    # Settings common to both loaders.
    shared_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': bool(device_type == 'cuda'),
    }

    train_loader = torch.utils.data.DataLoader(
        train_dataset, shuffle=True, drop_last=True, **shared_kwargs
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, shuffle=False, **shared_kwargs
    )

    logger.info(f"Train DataLoader: {len(train_loader)} batches of size {batch_size}.")
    logger.info(f"Test DataLoader: {len(test_loader)} batches of size {batch_size}.")
    return train_loader, test_loader
