import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image


def get_cifar10_transforms():
    """Build the train and test preprocessing pipelines for CIFAR-10.

    The training pipeline augments with a 32x32 random crop (padding of 4)
    and a random horizontal flip before converting to a tensor and
    normalizing. The test pipeline only converts to a tensor and
    normalizes with the same constants.

    Returns:
        tuple: ``(train_transform, test_transform)``, both
        ``transforms.Compose`` objects.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)

    def tensor_and_normalize():
        # Fresh transform instances on every call so the two pipelines
        # do not share any objects.
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ] + tensor_and_normalize()
    eval_steps = tensor_and_normalize()

    return transforms.Compose(train_steps), transforms.Compose(eval_steps)


def get_cifar100_transforms():
    """Build the train and test preprocessing pipelines for CIFAR-100.

    Compared with the CIFAR-10 pipeline in this file, training adds a
    random rotation of up to 15 degrees, color jitter, and random erasing.
    The test pipeline only converts to a tensor and normalizes.

    Returns:
        tuple: ``(train_transform, test_transform)``, both
        ``transforms.Compose`` objects.
    """
    # (mean, std) pair handed positionally to every Normalize below.
    norm_stats = ((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))

    # Augmentations that run before the image is converted to a tensor.
    pre_tensor_augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    ]
    train_steps = pre_tensor_augmentations + [
        transforms.ToTensor(),
        transforms.Normalize(*norm_stats),
        # Placed last, so erasing acts on the already-normalized tensor.
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.33)),
    ]
    eval_steps = [
        transforms.ToTensor(),
        transforms.Normalize(*norm_stats),
    ]

    return transforms.Compose(train_steps), transforms.Compose(eval_steps)


def get_mnist_transforms():
    """Build the train and test preprocessing pipelines for MNIST.

    No augmentation is applied: both pipelines convert to a tensor and
    normalize with a single-channel mean of 0.1307 and std of 0.3081.

    Returns:
        tuple: ``(train_transform, test_transform)``, two separate but
        identically configured ``transforms.Compose`` objects.
    """
    def build_pipeline():
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    # Train and test are configured the same; build one instance for each.
    return build_pipeline(), build_pipeline()


def get_svhn_transforms():
    """Build the train and test preprocessing pipelines for SVHN.

    No augmentation is applied: both pipelines convert to a tensor and
    normalize with the same three-channel mean/std constants.

    Returns:
        tuple: ``(train_transform, test_transform)``, two separate but
        identically configured ``transforms.Compose`` objects.
    """
    channel_mean = (0.4380, 0.4440, 0.4730)
    channel_std = (0.1751, 0.1771, 0.1744)

    def build_pipeline():
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
        return transforms.Compose(steps)

    train_pipeline = build_pipeline()
    eval_pipeline = build_pipeline()
    return train_pipeline, eval_pipeline


def get_tinyimagenet_transforms():
    """Build the train and test preprocessing pipelines for TinyImageNet.

    Training uses a 64x64 random crop (padding of 8), a random horizontal
    flip and RandAugment before tensor conversion, then normalization and
    random erasing. The test pipeline only converts to a tensor and
    normalizes with the same constants.

    Returns:
        tuple: ``(train_transform, test_transform)``, both
        ``transforms.Compose`` objects.
    """
    def tensor_and_normalize():
        # New list literals and transform instances on every call, so the
        # two pipelines never share mutable state.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    train_steps = [
        transforms.RandomCrop(64, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9),
    ]
    train_steps.extend(tensor_and_normalize())
    # Appended last, so erasing acts on the already-normalized tensor.
    train_steps.append(transforms.RandomErasing(p=0.5, scale=(0.02, 0.33)))

    eval_steps = tensor_and_normalize()

    return transforms.Compose(train_steps), transforms.Compose(eval_steps)


def load_cifar10(batch_size=128, num_workers=4, data_root='./data'):
    """Create train and test DataLoaders for CIFAR-10.

    Both splits are requested with ``download=True`` and rooted at
    ``data_root``.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Worker count used by both loaders.
        data_root: Directory passed to torchvision as the dataset root.

    Returns:
        tuple: ``(train_loader, test_loader, 10)``. The train loader
        shuffles, the test loader does not, and the final element is the
        number of classes.
    """
    train_tf, eval_tf = get_cifar10_transforms()

    train_data = torchvision.datasets.CIFAR10(
        root=data_root, train=True, download=True, transform=train_tf
    )
    eval_data = torchvision.datasets.CIFAR10(
        root=data_root, train=False, download=True, transform=eval_tf
    )

    # Options common to both loaders; only ``shuffle`` differs per split.
    loader_opts = dict(batch_size=batch_size, num_workers=num_workers)
    return (
        DataLoader(train_data, shuffle=True, **loader_opts),
        DataLoader(eval_data, shuffle=False, **loader_opts),
        10,
    )


def load_cifar100(batch_size=128, num_workers=4, data_root='./data'):
    """Create train and test DataLoaders for CIFAR-100.

    Both splits are requested with ``download=True`` and rooted at
    ``data_root``.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Worker count used by both loaders.
        data_root: Directory passed to torchvision as the dataset root.

    Returns:
        tuple: ``(train_loader, test_loader, 100)``. The train loader
        shuffles, the test loader does not, and the final element is the
        number of classes.
    """
    train_tf, eval_tf = get_cifar100_transforms()

    # (is_train, transform) per split; the train split comes first.
    split_specs = ((True, train_tf), (False, eval_tf))

    # Build both datasets before any loader is created.
    datasets = [
        torchvision.datasets.CIFAR100(
            root=data_root, train=is_train, download=True, transform=tf
        )
        for is_train, tf in split_specs
    ]
    # Only the training split (is_train=True) is shuffled.
    train_loader, test_loader = [
        DataLoader(
            data, batch_size=batch_size, shuffle=is_train, num_workers=num_workers
        )
        for data, (is_train, _) in zip(datasets, split_specs)
    ]

    return train_loader, test_loader, 100


def load_mnist(batch_size=128, num_workers=4, data_root='./data'):
    """Create train and test DataLoaders for MNIST.

    Both splits are requested with ``download=True`` and rooted at
    ``data_root``.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Worker count used by both loaders.
        data_root: Directory passed to torchvision as the dataset root.

    Returns:
        tuple: ``(train_loader, test_loader, 10)``. The train loader
        shuffles, the test loader does not, and the final element is the
        number of classes.
    """
    def make_loader(data, shuffle):
        return DataLoader(
            data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
        )

    train_tf, eval_tf = get_mnist_transforms()

    mnist_train = torchvision.datasets.MNIST(
        root=data_root, train=True, download=True, transform=train_tf
    )
    mnist_test = torchvision.datasets.MNIST(
        root=data_root, train=False, download=True, transform=eval_tf
    )

    num_classes = 10
    return make_loader(mnist_train, True), make_loader(mnist_test, False), num_classes


def load_svhn(batch_size=128, num_workers=4, data_root='./data/svhn'):
    """Create train and test DataLoaders for SVHN.

    The ``'train'`` and ``'test'`` splits are requested with
    ``download=True`` and rooted at ``data_root``.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Worker count used by both loaders.
        data_root: Directory passed to torchvision as the dataset root.

    Returns:
        tuple: ``(train_loader, test_loader, 10)``. The train loader
        shuffles, the test loader does not, and the final element is the
        number of classes.
    """
    train_tf, eval_tf = get_svhn_transforms()

    # Build the train split first, then the test split.
    split_data = {}
    for split_name, split_tf in (('train', train_tf), ('test', eval_tf)):
        split_data[split_name] = torchvision.datasets.SVHN(
            root=data_root, split=split_name, download=True, transform=split_tf
        )

    common = {'batch_size': batch_size, 'num_workers': num_workers}
    train_loader = DataLoader(split_data['train'], shuffle=True, **common)
    test_loader = DataLoader(split_data['test'], shuffle=False, **common)

    return train_loader, test_loader, 10


class TinyImageNetValDataset(Dataset):
    """TinyImageNet validation images labeled from ``val_annotations.txt``.

    The annotation file is read from ``<val_dir>/val_annotations.txt``.
    Each non-blank line is tab-separated: the first field is the image
    file name (resolved under ``<val_dir>/images``) and the second is the
    class id string, which is mapped to an integer via ``class_to_idx``.
    Any further fields on the line are ignored.

    Args:
        val_dir: Directory holding ``val_annotations.txt`` and ``images/``.
        class_to_idx: Mapping from class id string to integer label.
        transform: Optional callable applied to each loaded RGB image.

    Raises:
        KeyError: If an annotation names a class id missing from
            ``class_to_idx``.
        IndexError: If a non-blank annotation line has fewer than two
            tab-separated fields.
    """

    def __init__(self, val_dir, class_to_idx, transform=None):
        self.val_dir = val_dir
        self.transform = transform
        self.class_to_idx = class_to_idx
        self.image_paths = []
        self.labels = []

        images_dir = os.path.join(val_dir, 'images')
        val_annotations_path = os.path.join(val_dir, 'val_annotations.txt')
        with open(val_annotations_path, 'r', encoding='utf-8') as f:
            # Iterate lazily instead of materializing every line up front.
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no record.
                    continue
                parts = line.split('\t')
                img_name, class_id = parts[0], parts[1]
                self.image_paths.append(os.path.join(images_dir, img_name))
                self.labels.append(self.class_to_idx[class_id])

    def __len__(self):
        """Return the number of annotated validation images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and return ``(image, label)``."""
        # convert() yields an independent, fully loaded copy, so the file
        # handle can be closed right away instead of waiting for the GC.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def load_tinyimagenet(batch_size=128, num_workers=4, data_root='./data/tiny-imagenet-200'):
    """Create train and validation DataLoaders for TinyImageNet.

    Nothing is downloaded here: the data is read from ``<data_root>/train``
    (via ``ImageFolder``) and ``<data_root>/val`` (via
    ``TinyImageNetValDataset``).

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Worker count used by both loaders.
        data_root: Directory containing the ``train`` and ``val`` folders.

    Returns:
        tuple: ``(train_loader, test_loader, 200)``. The train loader
        shuffles, the validation loader does not, and the final element
        is the number of classes.
    """
    train_tf, eval_tf = get_tinyimagenet_transforms()

    train_data = torchvision.datasets.ImageFolder(
        root=os.path.join(data_root, 'train'), transform=train_tf
    )

    # Reuse the training set's class-id -> index mapping so validation
    # labels line up with the training labels.
    eval_data = TinyImageNetValDataset(
        val_dir=os.path.join(data_root, 'val'),
        class_to_idx=train_data.class_to_idx,
        transform=eval_tf,
    )

    shared_opts = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    return (
        DataLoader(train_data, shuffle=True, **shared_opts),
        DataLoader(eval_data, shuffle=False, **shared_opts),
        200,
    )


def load_dataset(dataset_name, batch_size=128, num_workers=4):
    """Dispatch to the loader registered for ``dataset_name``.

    The name is matched case-insensitively. Each loader is called with
    ``batch_size`` and ``num_workers`` only, so it uses its own default
    data root.

    Args:
        dataset_name: One of ``'cifar10'``, ``'cifar100'``, ``'mnist'``,
            ``'svhn'`` or ``'tinyimagenet'`` (any letter case).
        batch_size: Batch size forwarded to the selected loader.
        num_workers: Worker count forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    dataset_loaders = {
        'cifar10': load_cifar10,
        'cifar100': load_cifar100,
        'mnist': load_mnist,
        'svhn': load_svhn,
        'tinyimagenet': load_tinyimagenet
    }

    # Normalize once and reuse the key for both the check and the call.
    key = dataset_name.lower()
    loader_fn = dataset_loaders.get(key)
    if loader_fn is None:
        supported = list(dataset_loaders)
        raise ValueError(
            f"Unsupported dataset: {dataset_name}. Supported datasets: {supported}")

    return loader_fn(batch_size=batch_size, num_workers=num_workers)