from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import numpy as np
import random

def seed_worker(worker_id):
    """Seed NumPy's global RNG and Python's ``random`` module from torch's seed.

    Used as the DataLoader ``worker_init_fn`` elsewhere in this module, which is
    why it accepts ``worker_id`` even though the value is not used.

    The seed is ``torch.initial_seed()`` reduced modulo 2**32, because
    ``np.random.seed`` only accepts values in ``[0, 2**32)``.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    # NumPy first, then the stdlib RNG.
    for apply_seed in (np.random.seed, random.seed):
        apply_seed(derived_seed)

def get_dataloader(name, batch_size, img_size, root='../data', train=True, download=False, num_workers=0):
    """Return a DataLoader for the dataset selected by ``name``.

    ``name`` is matched case-insensitively against ``'mnist'``, ``'cifar10'``
    and ``'afhqv2'``. All remaining arguments are forwarded positionally to the
    matching dataset-specific factory.

    Raises:
        ValueError: If ``name`` does not match any supported dataset.
    """
    key = name.lower()
    forwarded = (batch_size, img_size, root, train, download, num_workers)
    if key == 'mnist':
        return get_mnist_dataloader(*forwarded)
    if key == 'cifar10':
        return get_cifar10_dataloader(*forwarded)
    if key == 'afhqv2':
        return get_afhqv2_dataloader(*forwarded)
    raise ValueError(f"Unknown dataset: {name}")

def get_afhqv2_dataloader(batch_size, img_size=(256, 256), root='../data', train=True, download=False, num_workers=0):
    """
    Creates a DataLoader for the AFHQv2 dataset.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``root/afhq_v2/train`` or ``root/afhq_v2/test``, depending on ``train``.

    Args:
        batch_size: Number of samples per batch.
        img_size: Output size passed to the crop (train) or resize (test)
            transform.
        root: Directory that contains, or will receive, the ``afhq_v2`` folder.
        train: If True, use the train split with random-resized-crop and
            horizontal-flip augmentation, shuffling, and ``drop_last``.
            Otherwise use the test split with a plain resize, no shuffling,
            and all samples kept.
        download: If the split directory is missing, download the archive and
            extract it into ``root`` instead of failing.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` over the split. Pixels are normalized per channel as
        ``(x - 0.5) / 0.5``.

    Raises:
        RuntimeError: If the split directory is missing and ``download`` is
            False, if downloading/extracting fails, or if the extracted archive
            does not contain the expected split directory.
    """
    import os

    data_root = os.path.join(root, 'afhq_v2', 'train' if train else 'test')

    # Require a directory: ImageFolder cannot read from a plain file.
    if not os.path.isdir(data_root):
        if not download:
            raise RuntimeError(f"AFHQv2 dataset not found at {data_root}. Please download and extract it there.")

        # Only needed when a download is actually requested.
        from torchvision.datasets.utils import download_and_extract_archive

        url = "https://www.dropbox.com/s/vkzjokiwof5h8w6/afhq_v2.zip?dl=1"
        try:
            print(f"Downloading AFHQv2 from {url}...")
            download_and_extract_archive(url, root, filename="afhq_v2.zip", remove_finished=True)
        except Exception as e:
            # Top-level boundary for the download step: re-raise with
            # instructions, keeping the original cause chained.
            raise RuntimeError(
                f"Auto-download failed due to network issues: {e}\n"
                f"Please manually download the dataset from: {url}\n"
                f"Then extract it to '{root}' so that the directory '{os.path.join(root, 'afhq_v2')}' exists."
            ) from e

        # The archive layout is not controlled here. Fail with a clear message
        # rather than an obscure ImageFolder error if the split is missing.
        if not os.path.isdir(data_root):
            raise RuntimeError(
                f"AFHQv2 archive was extracted into '{root}', but the expected directory "
                f"'{data_root}' does not exist. Please check the extracted layout."
            )

    bilinear = transforms.InterpolationMode.BILINEAR
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    if train:
        # Mild scale/aspect-ratio jitter plus horizontal flips for augmentation.
        transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0), ratio=(0.9, 1.1), interpolation=bilinear),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform = transforms.Compose([
            transforms.Resize(img_size, interpolation=bilinear),
            transforms.ToTensor(),
            normalize,
        ])

    dataset = datasets.ImageFolder(root=data_root, transform=transform)

    # Shuffle and drop the ragged final batch only while training.
    return DataLoader(dataset, batch_size=batch_size, shuffle=train,
                      num_workers=num_workers, drop_last=train, worker_init_fn=seed_worker)

def get_mnist_dataloader(batch_size, img_size=(64, 64), root='../data', train=True, download=False, num_workers=0):
    """Build a DataLoader over torchvision's MNIST dataset.

    Each image is bilinearly resized to ``img_size``, converted to a tensor,
    and normalized as ``(x - 0.5) / 0.5`` on its single channel.

    Args:
        batch_size: Number of samples per batch.
        img_size: Target size passed to ``transforms.Resize``.
        root: Directory handed to ``datasets.MNIST``.
        train: Selects the training split. It also enables shuffling and
            dropping the final incomplete batch.
        download: Forwarded to ``datasets.MNIST``.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` over the selected MNIST split.
    """
    pipeline = [
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    mnist = datasets.MNIST(
        root=root,
        train=train,
        download=download,
        transform=transforms.Compose(pipeline),
    )
    return DataLoader(
        mnist,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        drop_last=train,
        worker_init_fn=seed_worker,
    )

def get_cifar10_dataloader(batch_size, img_size=(64, 64), root='../data', train=True, download=False, num_workers=0):
    """Build a DataLoader over torchvision's CIFAR10 dataset.

    Each image is bilinearly resized to ``img_size``, converted to a tensor,
    and normalized as ``(x - 0.5) / 0.5`` on each of its three channels.

    Args:
        batch_size: Number of samples per batch.
        img_size: Target size passed to ``transforms.Resize``.
        root: Directory handed to ``datasets.CIFAR10``.
        train: Selects the training split. It also enables shuffling and
            dropping the final incomplete batch.
        download: Forwarded to ``datasets.CIFAR10``.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` over the selected CIFAR10 split.
    """
    half = (0.5, 0.5, 0.5)
    pipeline = [
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(half, half),
    ]
    cifar = datasets.CIFAR10(
        root=root,
        train=train,
        download=download,
        transform=transforms.Compose(pipeline),
    )
    return DataLoader(
        cifar,
        batch_size=batch_size,
        shuffle=train,
        num_workers=num_workers,
        drop_last=train,
        worker_init_fn=seed_worker,
    )

if __name__ == "__main__":
    # Smoke test: fetch one MNIST training batch and report its shape and value range.
    loader = get_mnist_dataloader(128, img_size=(64, 64), download=True)
    images, _targets = next(iter(loader))
    print(f"Batch shape: {images.shape}")
    print(f"Min value: {images.min()}, Max value: {images.max()}")
