import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import torchvision.datasets as datasets


class CelebAHQDataset(Dataset):
    """Dataset over CelebA-HQ image files stored directly in one directory.

    Files in ``root_dir`` whose names end in ``.png``, ``.jpg`` or ``.jpeg``
    (case-sensitive match) are indexed in sorted order. Every item is an
    ``(image, 0)`` pair.
    """

    def __init__(self, root_dir: str, transform=None):
        """Index the image files found in ``root_dir``.

        Args:
            root_dir: Directory that holds the image files.
            transform: Callable applied to each loaded RGB image. When it is
                omitted (or falsy) a default pipeline is used: resize to
                64x64, convert to tensor, normalize with mean 0.5 / std 0.5
                on each of the three channels.
        """
        self.root_dir = root_dir

        if transform:
            self.transform = transform
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((64, 64)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                ]
            )

        # os.listdir returns names in arbitrary order; sorting makes the
        # index -> file mapping deterministic.
        image_suffixes = (".png", ".jpg", ".jpeg")
        self.image_files = sorted(
            filter(lambda name: name.endswith(image_suffixes), os.listdir(root_dir))
        )

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, return ``(image, 0)``."""
        file_name = self.image_files[idx]
        rgb = Image.open(os.path.join(self.root_dir, file_name)).convert("RGB")
        sample = self.transform(rgb) if self.transform else rgb

        # Constant 0 label so items look like (image, label) pairs, matching
        # the CelebA interface.
        return sample, 0


def get_celeba_hq_dataloader(
    root_dir: str = "data/celebahq-resized-256x256/versions/1/celeba_hq_256",
    num_images: int = None,
    resolution: int = 128,
    batch_size: int = 16,
) -> DataLoader:
    """Returns DataLoader for the CelebA-HQ image directory.

    Each image is converted to a tensor and then resized to
    ``resolution`` x ``resolution``. No normalization step is applied.

    Args:
        root_dir: Directory containing the CelebA-HQ image files.
        num_images: If given, only the first ``num_images`` images (in the
            dataset's own order, capped at the dataset size) are used, the
            global torch seed is set to 42, and shuffling is driven by a
            generator seeded with 42. If None, the full dataset is used and
            no seeding is done.
        resolution: Side length, in pixels, of the square output images.
        batch_size: Number of samples per batch.

    Returns:
        A shuffling DataLoader with 4 workers and pinned memory; the final
        partial batch is kept.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize((resolution, resolution)),
            # No Normalize step: values stay as ToTensor produces them
            # rather than being rescaled to [-1, 1].
        ]
    )

    dataset = CelebAHQDataset(root_dir=root_dir, transform=transform)

    # Stays None for the full dataset, so the DataLoader falls back to its
    # default RNG for shuffling.
    generator = None
    if num_images is not None:
        from torch.utils.data import Subset

        seed = 42
        torch.manual_seed(seed)
        generator = torch.Generator().manual_seed(seed)
        # The subset is simply the first N samples. Clamp N so an oversized
        # request cannot produce out-of-range indices that would only fail
        # later, while iterating.
        dataset = Subset(dataset, range(min(num_images, len(dataset))))

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        generator=generator,
        drop_last=False,
    )


def get_cifar10_dataloader(num_images: int = None, batch_size: int = 80) -> DataLoader:
    """Returns DataLoader for CIFAR10 dataset.

    The CIFAR10 training split is read from ``./data`` (downloaded there if
    missing). Images are converted to tensors with no resizing or
    normalization.

    Args:
        num_images: If given, only the first ``num_images`` samples (capped
            at the dataset size) are used, the global torch seed is set to
            42, and shuffling is driven by a generator seeded with 42. If
            None, the whole split is used and no seeding is done.
        batch_size: Number of samples per batch.

    Returns:
        A shuffling DataLoader with 4 workers and pinned memory; the final
        partial batch is kept.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )

    # Use the module-level `datasets` alias: the file's imports bind
    # `transforms` and `datasets`, but never the bare name `torchvision`.
    dataset = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )

    # Must be bound on every path; None lets the DataLoader use its default
    # RNG when the full dataset is requested.
    g = None
    if num_images is not None:
        from torch.utils.data import Subset

        seed = 42
        torch.manual_seed(seed)
        g = torch.Generator().manual_seed(seed)  # Match seed from training_utils.py
        # The subset is the first N samples; clamp N so an oversized request
        # cannot yield out-of-range indices during iteration.
        dataset = Subset(dataset, range(min(num_images, len(dataset))))

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        generator=g,  # Fixed generator (subset case) for reproducible shuffling
        drop_last=False,  # Added to match training_utils.py
    )


def get_celeba_dataloader(
    num_images: int = None, batch_size: int = 80, resolution: int = 128
) -> DataLoader:
    """Returns DataLoader for CelebA dataset.

    The CelebA ``train`` split is read from ``./data`` (downloaded there if
    missing). Images are resized to ``resolution`` x ``resolution`` and
    converted to tensors; no cropping or normalization is applied.

    Args:
        num_images: If given, only the first ``num_images`` samples (capped
            at the dataset size) are used, the global torch seed is set to
            42, and shuffling is driven by a generator seeded with 42. If
            None, the whole split is used and no seeding is done.
        batch_size: Number of samples per batch.
        resolution: Side length, in pixels, of the square output images.

    Returns:
        A shuffling DataLoader with 4 workers and pinned memory; the final
        partial batch is kept.
    """
    transform = transforms.Compose(
        [
            # Resizing to a square directly (no center crop) changes the
            # aspect ratio of non-square inputs.
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
        ]
    )

    # Use the module-level `datasets` alias: the file's imports bind
    # `transforms` and `datasets`, but never the bare name `torchvision`.
    dataset = datasets.CelebA(
        root="./data", split="train", download=True, transform=transform
    )

    # Must be bound on every path; None lets the DataLoader use its default
    # RNG when the full dataset is requested.
    g = None
    if num_images is not None:
        from torch.utils.data import Subset

        seed = 42
        torch.manual_seed(seed)
        g = torch.Generator().manual_seed(seed)  # Match seed from training_utils.py
        # The subset is the first N samples; clamp N so an oversized request
        # cannot yield out-of-range indices during iteration.
        dataset = Subset(dataset, range(min(num_images, len(dataset))))

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        generator=g,  # Fixed generator (subset case) for reproducible shuffling
        drop_last=False,  # Added to match training_utils.py
    )


class FFHQDataset(Dataset):
    """Dataset for FFHQ images stored directly in one directory.

    Files in ``root_dir`` whose names end in ``.png``, ``.jpg`` or ``.jpeg``
    (case-sensitive match) are indexed in sorted order. Every item is an
    ``(image, 0)`` pair, where the image has been resized, converted to a
    tensor and normalized with mean 0.5 / std 0.5 per channel.
    """

    def __init__(self, root_dir: str, resolution: int = 64, limit: int = None):
        """Index the image files found in ``root_dir``.

        Args:
            root_dir: Directory that holds the image files.
            resolution: Side length, in pixels, of the square output images.
            limit: If given, keep only the first ``limit`` files of the
                sorted listing (negative values are treated as 0). If None,
                keep every image file.
        """
        self.root_dir = root_dir
        self.transform = transforms.Compose(
            [
                transforms.Resize((resolution, resolution)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        self.image_files = self._list_image_files(root_dir, limit)

    @staticmethod
    def _list_image_files(root_dir, limit=None):
        """Return the sorted image file names in ``root_dir``, at most ``limit``."""
        # os.listdir returns names in arbitrary order; sorting makes the
        # index -> file mapping deterministic.
        names = sorted(
            f for f in os.listdir(root_dir) if f.endswith((".png", ".jpg", ".jpeg"))
        )
        if limit is not None:
            # max() guards against a negative limit silently slicing from
            # the end of the list.
            names = names[: max(limit, 0)]
        return names

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(transformed_image, 0)``."""
        img_path = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_path).convert("RGB")
        # Constant 0 label so items look like (image, label) pairs.
        return self.transform(image), 0


def get_ffhq_dataloader(
    dataset_dir: str, num_images: int = None, batch_size: int = 32
) -> DataLoader:
    """Returns DataLoader for FFHQ dataset.

    Args:
        dataset_dir: Directory containing the FFHQ image files.
        num_images: If given, only the first ``num_images`` images (in the
            dataset's own order, capped at the dataset size) are used. If
            None, every image is used.
        batch_size: Number of samples per batch.

    Returns:
        A shuffling DataLoader with 4 workers and pinned memory.
    """
    dataset = FFHQDataset(dataset_dir)

    if num_images is not None:
        from torch.utils.data import Subset

        # Restrict the dataset here rather than through a constructor
        # keyword. Clamp so an oversized request cannot yield out-of-range
        # indices during iteration.
        dataset = Subset(dataset, range(min(num_images, len(dataset))))

    return DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True
    )


def get_dataset_loader(
    dataset_name: str,
    num_images: int = None,
    dataroot=".",
    batch_size: int = 80,
    **kwargs,
) -> DataLoader:
    """
    Build a seeded, shuffling DataLoader for a dataset selected by name.

    Args:
        dataset_name: One of 'mnist', 'fashion_mnist', 'cifar10', 'ffhq',
            'celeba_hq' or 'afhq'.
        num_images: If a positive integer, a random subset of that many
            images is drawn with a fixed-seed generator. If None or
            non-positive, the full dataset is used.
        dataroot: Base directory. The torchvision datasets use
            ``{dataroot}/data`` and AFHQ defaults to ``{dataroot}/data/afhq``.
            The FFHQ and CelebA-HQ default directories do not depend on it.
        batch_size: Batch size for the DataLoader.
        **kwargs: Optional directory overrides, read by key: ``ffhq_dir``,
            ``celeba_dir``, ``afhq_dir``.

    Returns:
        A DataLoader with shuffle=True, 4 workers, pinned memory and
        drop_last=True, sharing the fixed-seed generator used for the subset.

    Raises:
        ValueError: If ``dataset_name`` is not recognized, or if the AFHQ
            directory does not exist.
    """
    # A single generator with a fixed seed drives both the subset draw and
    # the DataLoader.
    rng = torch.Generator()
    rng.manual_seed(42)

    def normalized_tensor(channels):
        # ToTensor followed by mean-0.5 / std-0.5 normalization per channel.
        stats = (0.5,) * channels
        return transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(stats, stats)]
        )

    gray_pipeline = normalized_tensor(1)
    rgb_pipeline = normalized_tensor(3)

    if dataset_name in ("mnist", "fashion_mnist"):
        # Both grayscale datasets are built the same way; only the class differs.
        dataset_cls = (
            datasets.MNIST if dataset_name == "mnist" else datasets.FashionMNIST
        )
        dataset = dataset_cls(
            root=f"{dataroot}/data",
            train=True,
            download=True,
            transform=gray_pipeline,
        )
    elif dataset_name == "cifar10":
        dataset = datasets.CIFAR10(
            root=f"{dataroot}/data", train=True, download=True, transform=rgb_pipeline
        )
    elif dataset_name == "ffhq":
        # No transform or resolution is passed, so the dataset class's own
        # defaults apply.
        dataset = FFHQDataset(root_dir=kwargs.get("ffhq_dir", "data/ffhq_70k"))
    elif dataset_name == "celeba_hq":
        # No transform is passed, so the dataset class's default applies.
        dataset = CelebAHQDataset(
            root_dir=kwargs.get(
                "celeba_dir",
                "data/celebahq-resized-256x256/versions/1/celeba_hq_256",
            )
        )
    elif dataset_name == "afhq":
        # AFHQ is not downloaded here; it must already exist on disk.
        afhq_dir = kwargs.get("afhq_dir", f"{dataroot}/data/afhq")
        if not os.path.exists(afhq_dir):
            raise ValueError(
                f"AFHQ dataset not found at {afhq_dir}. Please download it first."
            )

        afhq_pipeline = transforms.Compose(
            [
                transforms.Resize((64, 64)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        dataset = datasets.ImageFolder(root=afhq_dir, transform=afhq_pipeline)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Optional random subset: the first num_images entries of a seeded
    # permutation of all indices.
    if num_images is not None and num_images > 0:
        shuffled = torch.randperm(len(dataset), generator=rng)
        dataset = torch.utils.data.Subset(dataset, shuffled[:num_images])

    loader_options = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 4,
        "pin_memory": True,
        "generator": rng,
        "drop_last": True,
    }
    return DataLoader(dataset, **loader_options)
