import torch
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision.transforms as transforms
from torchvision.datasets import CelebA, CIFAR10
import numpy as np

class CelebADataset(Dataset):
    """
    Binary-classification view over torchvision's CelebA dataset.

    One attribute (``target_attr``, 'Smiling' by default) is used as the label
    and, optionally, a second attribute (``protected_attr``, 'Male' by default)
    is returned as a group indicator.

    Items are ``(img, target, group)`` when ``protected_attr`` is set and
    ``(img, target)`` when it is ``None``.
    """

    def __init__(self, root, split='train', target_attr='Smiling', protected_attr='Male', transform=None):
        # CelebA is fetched into `root` when it is not already present there.
        self.dataset = CelebA(root=root, split=split, download=True, transform=transform)
        self.target_attr = target_attr
        self.protected_attr = protected_attr

        # Resolve attribute names to their column positions once, up front.
        # list.index raises ValueError for a name that is not in attr_names.
        names = self.dataset.attr_names
        self.target_idx = names.index(target_attr)
        if protected_attr is None:
            self.protected_idx = None
        else:
            self.protected_idx = names.index(protected_attr)

    @staticmethod
    def _to_binary(value):
        """Map an attribute entry to an int label via ``(v + 1) // 2``.

        -1 becomes 0 and 1 stays 1; entries already in {0, 1} are unchanged.
        """
        return (value.item() + 1) // 2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, attributes = self.dataset[index]
        label = self._to_binary(attributes[self.target_idx])
        if self.protected_attr is None:
            return image, label
        membership = self._to_binary(attributes[self.protected_idx])
        return image, label, membership

class CIFAR10Binary(Dataset):
    """
    A custom dataset wrapper for CIFAR10 for binary classification.

    Only samples labelled `class_pos` or `class_neg` are kept; `class_pos`
    samples get label 1 and `class_neg` samples get label 0.
    Supports imbalance by downsampling the positive class.

    Args:
        root: Directory the CIFAR10 data is stored in (downloaded if missing).
        split: 'train' uses the CIFAR10 training set. Any other value uses the
            CIFAR10 test set: 'valid' keeps the first `val_split` fraction of
            the shuffled selection, 'test' keeps the remainder, and any other
            string keeps the whole selection.
        class_pos: CIFAR10 class id mapped to label 1.
        class_neg: CIFAR10 class id mapped to label 0.
        imbalance_ratio: Fraction of positive samples to keep. Values >= 1.0
            disable downsampling.
        val_split: Fraction of the test-set selection assigned to 'valid'.
        transform: Transform forwarded to the underlying CIFAR10 dataset.
        seed: Seed for the shuffles. 'valid' and 'test' are only guaranteed
            to be disjoint when built with the same seed and settings.

    Raises:
        ValueError: If a class is None, both classes are equal,
            `imbalance_ratio` is negative, or `val_split` is outside [0, 1].
    """

    def __init__(self, root, split='train', class_pos=0, class_neg=9, imbalance_ratio=1.0,
                 val_split=0.1, transform=None, seed=42):
        # Reject bad arguments before triggering a (potentially slow) download.
        self._validate_args(class_pos, class_neg, imbalance_ratio, val_split)

        self.dataset = CIFAR10(root=root, train=(split == 'train'), download=True, transform=transform)
        self.class_pos = class_pos
        self.class_neg = class_neg
        self.transform = transform
        self.indices = self._select_indices(
            self.dataset.targets, split, class_pos, class_neg,
            imbalance_ratio, val_split, seed,
        )

    @staticmethod
    def _validate_args(class_pos, class_neg, imbalance_ratio, val_split):
        """Raise ValueError for argument combinations that cannot be honoured."""
        if class_pos is None or class_neg is None:
            raise ValueError("Both class_pos and class_neg must be specified.")
        if class_pos == class_neg:
            raise ValueError("class_pos and class_neg must be different classes.")
        # Negative fractions would turn into negative slice bounds and silently
        # select the wrong samples.
        if imbalance_ratio < 0:
            raise ValueError("imbalance_ratio must be non-negative.")
        if not 0.0 <= val_split <= 1.0:
            raise ValueError("val_split must be in the range [0, 1].")

    @staticmethod
    def _select_indices(targets, split, class_pos, class_neg, imbalance_ratio, val_split, seed):
        """Return the shuffled list of positions in `targets` belonging to this split."""
        pos_indices = [i for i, label in enumerate(targets) if label == class_pos]
        neg_indices = [i for i, label in enumerate(targets) if label == class_neg]

        # Private RandomState instances yield the same permutations as the
        # legacy `np.random.seed(seed); np.random.shuffle(...)` pair, but leave
        # the process-wide NumPy RNG untouched.
        if imbalance_ratio < 1.0:
            np.random.RandomState(seed).shuffle(pos_indices)
            pos_indices = pos_indices[:int(len(pos_indices) * imbalance_ratio)]

        indices = pos_indices + neg_indices
        np.random.RandomState(seed).shuffle(indices)

        # 'valid' and 'test' partition the same shuffled selection, so they do
        # not overlap as long as both are built with identical arguments.
        if split == 'valid':
            indices = indices[:int(len(indices) * val_split)]
        elif split == 'test':
            indices = indices[int(len(indices) * val_split):]
        return indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        img, label = self.dataset[self.indices[idx]]
        # Only class_pos / class_neg samples are indexed, so "not positive"
        # means negative here.
        label = 1 if label == self.class_pos else 0
        return img, label

def get_dataloader(config, split='train', seed=42):
    """
    Unified dataloader getter function to align with existing implementation.
    """
    # Define transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    rng = np.random.default_rng(seed)
    dataset_name = config['name']
    
    match dataset_name:
        case 'CelebA':
            root = "./data/celeba"
            target_attr = config.get('target_attr', 'Smiling')
            protected_attr = config.get('protected_attr', None)
            dataset = CelebADataset(root, split=split, target_attr=target_attr,
                                    protected_attr=protected_attr, transform=transform)
        case 'CIFAR10Binary':
            root = "./data/cifar10"
            class_pos = config.get('class_pos',0)
            class_neg = config.get('class_neg',9)
            imbalance_ratio = config.get('imbalance_ratio', 1.0)
            dataset = CIFAR10Binary(root, split=split, class_pos=class_pos,
                                    class_neg=class_neg, imbalance_ratio=imbalance_ratio,
                                    transform=transform)
        case _:
            raise NotImplementedError('Dataset not implemented')
    
    if split == 'train' and 'train_n' in config:
        indices = rng.choice(len(dataset), size=config['train_n'], replace=False)
        subset = Subset(dataset, indices)
    if split == 'valid' and 'valid_n' in config:
        indices = rng.choice(len(dataset), size=config['valid_n'], replace=False)
        subset = Subset(dataset, indices)
    else:
        subset = dataset

    dataloader = DataLoader(subset,
                            batch_size=config['batch_size'],
                            shuffle=(split == 'train' or split == 'valid'),
                            num_workers=config['num_workers'])
    return dataloader
