import math
import copy
import random
import os
import pickle
from PIL import Image
from typing import Tuple, Optional
import numpy as np
from torch.utils.data import Subset

import torch
import torchvision
import torchvision.transforms.v2 as T
from torch.utils.data import Dataset, random_split, ConcatDataset

from library import configs
from library import misc
# from library import mnist_dataset
from library import uci_datasets

class TransformedDataset(Dataset):
    """Dataset view that applies ``transform`` to the input of every sample.

    Wrapping lets several views of the same underlying data (for example the
    train and validation parts of one split) use different transforms. Each
    item of ``subset`` must unpack into an ``(input, target)`` pair; the
    target is passed through untouched.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        sample, target = self.subset[idx]
        transformed = self.transform(sample)
        return transformed, target

class BinaryfyLabel(torch.utils.data.Dataset):
    """Dataset view that replaces an integer class label by a block-coded vector.

    The target becomes a length-``n`` float vector made of ``num_classes``
    contiguous blocks of ``n // num_classes`` entries; only the block that
    belongs to the sample's label is set to 1.0.
    """

    def __init__(self, dataset, n, num_classes=10):
        self.dataset = dataset
        self.n = n
        self.num_classes = num_classes
        assert n % num_classes == 0, "n must be divisible by number of classes"
        self.block_size = n // num_classes

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, label = self.dataset[idx]
        # Block of the target vector owned by this label.
        lo = label * self.block_size
        hi = lo + self.block_size
        target = torch.zeros(self.n)
        target[lo:hi] = 1.0
        return sample, target

class BinarizeInput:
    """Callable transform: values strictly above ``threshold`` become 1.0, all others 0.0."""

    def __init__(self, threshold=0.5):
        self.threshold = threshold

    def __call__(self, tensor):
        above = tensor > self.threshold
        return above.float()

    def __repr__(self):
        return "{}(threshold={})".format(self.__class__.__name__, self.threshold)

class AddGaussianNoise(object):
    """Callable transform that adds Gaussian noise to a tensor.

    Args:
        mean: Constant offset added to every element.
        std: Scale applied to the standard-normal noise.
        clamp: When true, the result is clamped to the range [0, 1].
    """

    def __init__(self, mean=0., std=0.1, clamp=True):
        self.mean = mean
        self.std = std
        self.clamp = clamp

    def __call__(self, tensor):
        # Sample the noise on the input's device; a CPU-only noise tensor
        # cannot be added to an input that lives on another device.
        noise = torch.randn(tensor.size(), device=tensor.device)
        noisy = tensor + noise * self.std + self.mean
        if self.clamp:
            noisy = torch.clamp(noisy, 0., 1.)
        return noisy

    def __repr__(self):
        return (self.__class__.__name__ +
                f'(mean={self.mean}, std={self.std}, clamp={self.clamp})')

def load_dataset(config: configs.DifflogicConfig) -> Tuple[
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
    torch.utils.data.DataLoader,
    Optional[torch.utils.data.DataLoader],
    Optional[torch.utils.data.DataLoader],
]:
    """Build the train/validation/test dataloaders described by ``config``.

    The training data of the selected dataset is split randomly into a train
    and a validation part (``data_config.valid_set_size`` is the validation
    fraction); each part is wrapped with its own transform pipeline.

    Args:
        config: Experiment configuration. Reads ``config.data_config`` and the
            batch sizes in ``config.train_config`` / ``config.test_config``.

    Returns:
        ``(train_loader, validation_loader, test_loader,
        validation_bin_loader, test_loader_bin)``.
        ``validation_bin_loader`` is ``None`` when
        ``data_config.eval_binarized`` is 0. ``test_loader_bin`` is ``None``
        in that case too, and also for datasets that provide no binarized
        test set.

    Raises:
        NotImplementedError: for datasets that are not (fully) supported, for
            an unknown EMNIST split, and when ``data_config.valid_set_size``
            is not positive.
    """
    data_config = config.data_config

    train_batch_size = config.train_config.batch_size
    test_batch_size = config.test_config.batch_size

    train_transform_list = []
    valid_transform_list = []
    test_transform_list = []

    if data_config.upscale_input != 0:
        test_transform_list.append(T.Resize((data_config.upscale_input, data_config.upscale_input)))
        valid_transform_list.append(T.Resize((data_config.upscale_input, data_config.upscale_input)))
        train_transform_list.append(T.Resize((data_config.upscale_input, data_config.upscale_input)))

    if data_config.augmentation:
        train_transform_list.extend([
            T.RandomAffine(
                degrees=3, # before 10
                translate=(0.05, 0.05),   # small horizontal/vertical shifts # before (0.1, 0.1)
                scale=(0.95, 1.05),       # slight zoom in/out # before (0.9, 1.1)
                shear=2                 # mild shear # before 5
            ),
            T.ToTensor(),
            AddGaussianNoise(0, 0.02), # before 0.05
        ])
    else:
        train_transform_list.append(T.ToTensor())

    test_transform_list.append(T.ToTensor())
    valid_transform_list.append(T.ToTensor())

    if data_config.binarize_input_train != 0:
        train_transform_list.append(BinarizeInput(threshold=data_config.binarize_input_train))

    # Always define the list: the CIFAR / ImageNet32 branches below pass it on
    # even when binarized evaluation is switched off.
    valid_bin_transform_list = copy.deepcopy(valid_transform_list)
    validation_bin_loader = None
    if data_config.eval_binarized != 0:
        valid_bin_transform_list.append(BinarizeInput(threshold=data_config.eval_binarized))
        valid_bin_transform = T.Compose(valid_bin_transform_list)

        test_bin_transform_list = copy.deepcopy(test_transform_list)
        test_bin_transform_list.append(BinarizeInput(threshold=data_config.eval_binarized))
        test_bin_transform = T.Compose(test_bin_transform_list)
    else:
        valid_bin_transform = None
        test_bin_transform = None

    train_transform = T.Compose(train_transform_list)
    test_transform = T.Compose(test_transform_list)
    valid_transform = T.Compose(valid_transform_list)

    if (data_config.dataset == configs.Dataset.ADULT
            or data_config.dataset == configs.Dataset.BREAST_CANCER
            or data_config.dataset.value.startswith('monk')):
        # The uci_datasets loaders (Adult, BreastCancer, MONKs) are not wired
        # into the train/validation split below yet.
        raise NotImplementedError("This dataset is not yet fully supported")

    elif data_config.dataset.value.startswith('mnist'):
        data_dir = os.path.join(data_config.data_dir, 'mnist')
        full_train_dataset = torchvision.datasets.MNIST(data_dir, train=True, download=data_config.download, transform=None)
        test_set = torchvision.datasets.MNIST(data_dir, train=False, download=data_config.download, transform=test_transform)
        test_set_bin = torchvision.datasets.MNIST(data_dir, train=False, download=data_config.download, transform=test_bin_transform)
    ########## Added Datasets ###########
    elif data_config.dataset.value.startswith('fmnist'):
        data_dir = os.path.join(data_config.data_dir, 'fmnist')
        full_train_dataset = torchvision.datasets.FashionMNIST(data_dir, train=True, download=data_config.download, transform=None)
        test_set = torchvision.datasets.FashionMNIST(data_dir, train=False, download=data_config.download, transform=test_transform)
        test_set_bin = torchvision.datasets.FashionMNIST(data_dir, train=False, download=data_config.download, transform=test_bin_transform)
    elif data_config.dataset.value.startswith('kmnist'):
        full_train_dataset, raw_test_set = get_kmnist_dataset(data_config)
        test_set = TransformedDataset(raw_test_set, test_transform)
        test_set_bin = TransformedDataset(raw_test_set, test_bin_transform)
    elif data_config.dataset.value.startswith('qmnist'):
        data_dir = os.path.join(data_config.data_dir, 'qmnist')
        full_train_dataset = torchvision.datasets.QMNIST(data_dir, train=True, download=data_config.download, transform=None)
        test_set = torchvision.datasets.QMNIST(data_dir, train=False, download=data_config.download, transform=test_transform)
        test_set_bin = torchvision.datasets.QMNIST(data_dir, train=False, download=data_config.download, transform=test_bin_transform)
    elif data_config.dataset.value.startswith('emnist'):
        data_dir = os.path.join(data_config.data_dir, 'emnist')
        if data_config.dataset.value.startswith('emnist_balanced'):
            full_train_dataset = torchvision.datasets.EMNIST(data_dir, split="balanced", train=True, download=data_config.download, transform=None)
            test_set = torchvision.datasets.EMNIST(root=data_dir, split="balanced", train=False, download=data_config.download, transform=test_transform)
            test_set_bin = torchvision.datasets.EMNIST(root=data_dir, split="balanced", train=False, download=data_config.download, transform=test_bin_transform)
        elif data_config.dataset.value.startswith('emnist_letters'):
            # The "letters" labels are shifted down by one here.
            full_train_dataset = torchvision.datasets.EMNIST(data_dir, split="letters", train=True, download=data_config.download, transform=None, target_transform=lambda y: y - 1)
            test_set = torchvision.datasets.EMNIST(root=data_dir, split="letters", train=False, download=data_config.download, transform=test_transform, target_transform=lambda y: y - 1)
            test_set_bin = torchvision.datasets.EMNIST(root=data_dir, split="letters", train=False, download=data_config.download, transform=test_bin_transform, target_transform=lambda y: y - 1)
        else:
            # Previously fell through and failed later with an unbound local.
            raise NotImplementedError(f'The data set {data_config.dataset} is not supported!')
    elif data_config.dataset.value.startswith('custom'):
        if data_config.dataset.value == 'custom_imagenet':
            full_train_dataset, raw_test_set, transforms = get_custom_imagenet(data_config)
            train_transform, test_transform, valid_transform, valid_bin_transform = transforms
            test_set = TransformedDataset(raw_test_set, test_transform)
            test_set_bin = None
        else:
            full_train_dataset, raw_test_set = get_custom_dataset_2(data_config)
            test_set = TransformedDataset(raw_test_set, test_transform)
            test_set_bin = TransformedDataset(raw_test_set, test_bin_transform)
    elif data_config.dataset.value.startswith('synthetic'):
        # Configs without ``lower_bound_fixed`` fall back to a fixed bit count.
        try:
            lower_bound_fixed = data_config.lower_bound_fixed
        except AttributeError:
            lower_bound_fixed = data_config.fixed_bits_per_class
        full_train_dataset, test_set = create_synthetic_binary_datasets(num_classes=data_config.num_classes,
                                                                        samples_per_class_train=data_config.samples_per_class_train,
                                                                        samples_per_class_test=data_config.samples_per_class_test,
                                                                        input_size=data_config.input_size, # 784
                                                                        fixed_bits_per_class=data_config.fixed_bits_per_class,
                                                                        lower_bound_fixed=lower_bound_fixed
                                                                        )

        test_set = TransformedDataset(test_set, test_transform)
        test_set_bin = None
    ####################################

    elif data_config.dataset == configs.Dataset.CIFAR10:
        full_train_dataset, test_set, transforms = get_cifar_dataset(
            data_config,
            [train_transform_list, test_transform_list, valid_transform_list, valid_bin_transform_list]
        )
        train_transform, test_transform, valid_transform, valid_bin_transform = transforms
        test_set = TransformedDataset(test_set, test_transform)
        test_set_bin = None

    elif data_config.dataset == configs.Dataset.CIFAR100:
        full_train_dataset, test_set, transforms = get_cifar100_dataset(
            data_config,
            [train_transform_list, test_transform_list, valid_transform_list, valid_bin_transform_list]
        )
        train_transform, test_transform, valid_transform, valid_bin_transform = transforms
        test_set = TransformedDataset(test_set, test_transform)
        test_set_bin = None
    elif data_config.dataset == configs.Dataset.IMAGENET32:
        full_train_dataset, test_set, transforms = get_imagenet32_dataset(
            data_config,
            [train_transform_list, test_transform_list, valid_transform_list, valid_bin_transform_list]
        )
        train_transform, test_transform, valid_transform, valid_bin_transform = transforms
        # The ImageNet32 test set already carries its transform, so it is not wrapped again.
        test_set_bin = None
    else:
        raise NotImplementedError(f'The data set {data_config.dataset} is not supported!')

    if data_config.eval_binarized == 0:
        # No binarizing transform exists in this case; a test set built on a
        # ``None`` transform cannot be batched, so expose "no loader" instead.
        test_set_bin = None

    train_size = int((1 - data_config.valid_set_size) * len(full_train_dataset))
    valid_size = len(full_train_dataset) - train_size
    ##################################
    # print(f"n={config.model_config.last_layer_neurons}, n_c={data_config.num_classes}")
    # full_train_dataset = BinaryfyLabel(full_train_dataset, n=config.model_config.last_layer_neurons, num_classes=data_config.num_classes)
    # test_set = BinaryfyLabel(test_set, n=config.model_config.last_layer_neurons, num_classes=data_config.num_classes)
    ##################################
    if data_config.valid_set_size > 0:
        train_subset, valid_subset = random_split(full_train_dataset, [train_size, valid_size])
        train_dataset = TransformedDataset(train_subset, train_transform)
        valid_dataset = TransformedDataset(valid_subset, valid_transform)
        if data_config.eval_binarized != 0:
            valid_bin_dataset = TransformedDataset(valid_subset, valid_bin_transform)
            validation_bin_loader = torch.utils.data.DataLoader(valid_bin_dataset,
                                                                batch_size=test_batch_size,
                                                                shuffle=False,
                                                                pin_memory=data_config.pin_memory,
                                                                drop_last=False)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=train_batch_size,
                                                   shuffle=True,
                                                   pin_memory=data_config.pin_memory,
                                                   drop_last=True,
                                                   num_workers=data_config.num_workers)
        test_loader = torch.utils.data.DataLoader(test_set,
                                                  batch_size=test_batch_size,
                                                  shuffle=False,
                                                  pin_memory=data_config.pin_memory,
                                                  drop_last=False,
                                                  num_workers=data_config.num_workers)
        validation_loader = torch.utils.data.DataLoader(valid_dataset,
                                                        batch_size=test_batch_size,
                                                        shuffle=False,
                                                        pin_memory=data_config.pin_memory,
                                                        drop_last=False)
        if test_set_bin is not None:
            test_loader_bin = torch.utils.data.DataLoader(test_set_bin,
                                                          batch_size=test_batch_size,
                                                          shuffle=False,
                                                          pin_memory=data_config.pin_memory,
                                                          drop_last=False,
                                                          num_workers=data_config.num_workers)
        else:
            test_loader_bin = None
    else:
        # TODO: support training on the full training set (train_subset =
        # full_train_dataset) without a validation split. Until then fail
        # explicitly instead of hitting an unbound local at the return below.
        raise NotImplementedError(
            "valid_set_size must be > 0: loading without a validation split is not supported yet"
        )

    return train_loader, validation_loader, test_loader, validation_bin_loader, test_loader_bin


def load_n(loader: torch.utils.data.DataLoader, n: int):
    """Yield exactly ``n`` batches from ``loader``, restarting it as often as needed.

    The loader is re-iterated from scratch for every pass (rather than cached),
    so each pass goes through the loader's own iteration logic again.

    Args:
        loader: Any re-iterable source of batches.
        n: Number of batches to yield; nothing is yielded when ``n <= 0``.

    Yields:
        The batches produced by ``loader``, in iteration order.

    Raises:
        ValueError: if a full pass over ``loader`` produces no batch, which
            would otherwise make this generator spin forever.
    """
    produced = 0
    while produced < n:
        produced_before_pass = produced
        for batch in loader:
            yield batch
            produced += 1
            if produced == n:
                return
        if produced == produced_before_pass:
            raise ValueError("load_n: loader yielded no batches, cannot produce the requested items")

def get_custom_dataset_2(data_config):
    """Assemble a classification task from randomly chosen classes of several datasets.

    ``data_config.num_classes`` distinct (dataset, class) pairs are drawn from
    FashionMNIST (10 classes), KMNIST (10) and EMNIST "balanced" (47). The
    chosen classes are relabelled ``0 .. num_classes - 1`` and their samples
    are shuffled.

    Note: this seeds Python's global ``random`` module with
    ``data_config.seed``.

    Args:
        data_config: Reads ``data_dir``, ``num_classes``, ``seed`` and
            ``download``.

    Returns:
        ``(full_train_dataset, test_set)`` as ``TensorDataset`` objects holding
        ``ToTensor``-converted images and the remapped labels.
    """
    data_dir = data_config.data_dir
    num_classes = data_config.num_classes
    assert 2 <= num_classes <= 67, "Custom supports 2 to 67 classes."
    random.seed(data_config.seed)

    datasets = [
        (torchvision.datasets.FashionMNIST, "fmnist", 0, 10),
        (torchvision.datasets.KMNIST, "kmnist", 10, 10),
        (lambda root, train, download, transform: torchvision.datasets.EMNIST(
            root, split="balanced", train=train, download=download, transform=None
        ), "emnist_balanced", 20, 47),
    ]

    # Pool of all possible (dataset, class_id) pairs
    available_classes = [
        (loader_fn, subdir, offset, cls)
        for loader_fn, subdir, offset, max_cls in datasets
        for cls in range(max_cls)
    ]

    # Randomly choose 'num_classes' unique (dataset, class_id) pairs
    chosen_classes = random.sample(available_classes, num_classes)

    # Group chosen classes by dataset so each dataset is loaded only once
    from collections import defaultdict
    grouped = defaultdict(list)
    for loader_fn, subdir, offset, cls in chosen_classes:
        grouped[(loader_fn, subdir, offset)].append(cls)

    def collect(dataset, class_list, first_label):
        """Return ``(img, new_label)`` pairs for ``class_list`` from one pass over ``dataset``.

        Samples come out grouped by class in ``class_list`` order, each group
        in dataset order; the k-th class gets label ``first_label + k``.
        """
        buckets = {cls: [] for cls in class_list}
        for img, label in dataset:
            bucket = buckets.get(int(label))
            if bucket is not None:
                bucket.append(img)
        samples = []
        for k, cls in enumerate(class_list):
            samples.extend((img, first_label + k) for img in buckets[cls])
        return samples

    all_train_samples = []
    all_test_samples = []
    to_tensor = T.ToTensor()
    current_label = 0

    for (loader_fn, subdir, _offset), class_list in grouped.items():
        root = os.path.join(data_dir, subdir)

        if subdir == "kmnist":
            # KMNIST goes through the project's own re-split helper.
            train_dataset, test_dataset = get_kmnist_dataset(data_config)
        else:
            train_dataset = loader_fn(root=root, train=True, download=data_config.download, transform=None)
            test_dataset = loader_fn(root=root, train=False, download=data_config.download, transform=None)

        # One pass per split instead of one full pass per selected class.
        all_train_samples.extend(collect(train_dataset, class_list, current_label))
        all_test_samples.extend(collect(test_dataset, class_list, current_label))
        current_label += len(class_list)

    random.shuffle(all_train_samples)
    random.shuffle(all_test_samples)

    train_images, train_labels = zip(*all_train_samples)
    test_images, test_labels = zip(*all_test_samples)

    full_train_dataset = torch.utils.data.TensorDataset(
        torch.stack([to_tensor(img) for img in train_images]),
        torch.tensor(train_labels)
    )

    test_set = torch.utils.data.TensorDataset(
        torch.stack([to_tensor(img) for img in test_images]),
        torch.tensor(test_labels)
    )

    return full_train_dataset, test_set

def get_custom_imagenet(data_config):
    """Load ImageNet32 restricted to a random subset of its classes.

    ``data_config.num_classes`` classes are sampled from the labels found in
    the training batches (after seeding Python's global ``random`` module with
    ``data_config.seed``) and renumbered from 0. Images of the kept classes
    are converted with ``ToTensor`` and stored in ``TensorDataset`` objects.

    Args:
        data_config: Reads ``num_classes``, ``seed``, ``data_dir`` and
            ``image_thresholds``.

    Returns:
        ``(train_dataset, test_dataset, transforms)`` where ``transforms`` is
        a list holding the same thresholding pipeline four times.
    """
    num_classes = data_config.num_classes
    assert 2 <= num_classes <= 1000, "ImageNet32 has 1000 classes."
    random.seed(data_config.seed)

    data_dir = os.path.join(data_config.data_dir, 'imagenet32')
    image_thresholds = data_config.image_thresholds

    def threshold_planes(x):
        # One 0/1 copy of the input per threshold, stacked along dim 0.
        planes = [(x > (i + 1) / (image_thresholds + 1)).float() for i in range(image_thresholds)]
        return torch.cat(planes, dim=0)

    transform = T.Compose([
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
        T.Lambda(threshold_planes),
    ])

    # Batch files on disk
    train_files = [os.path.join(data_dir, f'imagenet32_train/train_data_batch_{i}') for i in range(1, 11)]
    test_file = [os.path.join(data_dir, 'imagenet32_val/val_data')]

    raw_train = ImageNet32Dataset(train_files, transform=None)
    raw_test = ImageNet32Dataset(test_file, transform=None)

    # Pick the classes to keep and renumber them
    all_classes = list(set(raw_train.labels))
    print("Available classes:", sorted(all_classes))
    print("Number of classes found:", len(all_classes))
    selected_classes = set(random.sample(all_classes, num_classes))
    class_map = {orig: idx for idx, orig in enumerate(selected_classes)}

    def filter_and_transform(dataset, loc_transform=T.ToTensor()):
        """Keep samples of the selected classes, converted and relabelled."""
        kept = [
            (loc_transform(img), class_map[label])
            for img, label in dataset
            if label in selected_classes
        ]
        images = torch.stack([img for img, _ in kept])
        targets = torch.tensor([target for _, target in kept])
        return torch.utils.data.TensorDataset(images, targets)

    train_dataset = filter_and_transform(raw_train)
    test_dataset = filter_and_transform(raw_test)

    return train_dataset, test_dataset, [transform] * 4

def get_kmnist_dataset(data_config):
    """Load KMNIST and re-split it with a seeded random permutation.

    The official train and test parts are concatenated and shuffled with a
    generator seeded by ``data_config.seed``; the first ``len(train)`` indices
    form the new training set and the rest the new test set. No transform is
    attached to the samples.

    Returns:
        ``(full_train_dataset, test_set)`` as ``Subset`` views of the
        concatenated data.
    """
    data_dir = os.path.join(data_config.data_dir, 'kmnist')

    official_train, official_test = [
        torchvision.datasets.KMNIST(data_dir, train=is_train, download=data_config.download, transform=None)
        for is_train in (True, False)
    ]
    pooled = ConcatDataset([official_train, official_test])

    # Keep the official train size; everything after it becomes the test set.
    train_size = len(official_train)

    generator = torch.Generator().manual_seed(data_config.seed)
    permutation = torch.randperm(len(pooled), generator=generator)

    full_train_dataset = torch.utils.data.Subset(pooled, permutation[:train_size])
    test_set = torch.utils.data.Subset(pooled, permutation[train_size:])
    return full_train_dataset, test_set

def get_cifar_dataset(data_config, transform_lists):
    """Load CIFAR-10 (untransformed) plus one thresholding pipeline per entry of ``transform_lists``.

    Only the number of entries in ``transform_lists`` is used; their contents
    are ignored. Every returned pipeline converts the image to a float tensor
    scaled by ``ToDtype(..., scale=True)`` and then concatenates, along dim 0,
    one 0/1 copy per threshold ``(i + 1) / (image_thresholds + 1)``.

    Returns:
        ``(full_train_dataset, test_set, transforms)``; both datasets are
        created with ``transform=None``.
    """
    data_dir = os.path.join(data_config.data_dir, 'cifar')
    image_thresholds = data_config.image_thresholds

    def threshold_planes(x):
        planes = [(x > (i + 1) / (image_thresholds + 1)).float() for i in range(image_thresholds)]
        return torch.cat(planes, dim=0)

    def build_pipeline():
        return T.Compose([
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),
            T.Lambda(threshold_planes),
        ])

    transforms = [build_pipeline() for _ in transform_lists]

    full_train_dataset, test_set = (
        torchvision.datasets.CIFAR10(data_dir, train=is_train, download=data_config.download, transform=None)
        for is_train in (True, False)
    )
    return full_train_dataset, test_set, transforms
# Disabled alternative to get_cifar_dataset, kept inside a bare module-level
# string so it is never executed. Unlike the active function, this variant
# appends the thresholding transforms to a copy of each incoming list instead
# of discarding the list's contents.
"""

def get_cifar_dataset(data_config, transform_lists):
    data_dir = os.path.join(data_config.data_dir, 'cifar')
    image_thresholds = data_config.image_thresholds
    transforms = []

    for lst in transform_lists:
        full_lst = list(lst) + [  # copy and extend safely
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),
            T.Lambda(lambda x: torch.cat([(x > (i + 1) / (image_thresholds + 1)).float() for i in range(image_thresholds)], dim=0))
        ]
        transforms.append(T.Compose(full_lst))

    full_train_dataset = torchvision.datasets.CIFAR10(data_dir, train=True, download=data_config.download, transform=None)
    test_set = torchvision.datasets.CIFAR10(data_dir, train=False, download=data_config.download, transform=None)

    return full_train_dataset, test_set, transforms
"""

def get_cifar100_dataset(data_config, transform_lists):
    """Load CIFAR-100 (untransformed) plus one thresholding pipeline per entry of ``transform_lists``.

    Only the number of entries in ``transform_lists`` is used; their contents
    are ignored. Every returned pipeline converts the image to a float tensor
    scaled by ``ToDtype(..., scale=True)`` and then concatenates, along dim 0,
    one 0/1 copy per threshold ``(i + 1) / (image_thresholds + 1)``.

    Returns:
        ``(full_train_dataset, test_set, transforms)``; both datasets are
        created with ``transform=None``.
    """
    data_dir = os.path.join(data_config.data_dir, 'cifar100')
    image_thresholds = data_config.image_thresholds

    def threshold_planes(x):
        planes = [(x > (i + 1) / (image_thresholds + 1)).float() for i in range(image_thresholds)]
        return torch.cat(planes, dim=0)

    def build_pipeline():
        return T.Compose([
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),
            T.Lambda(threshold_planes),
        ])

    transforms = [build_pipeline() for _ in transform_lists]

    full_train_dataset, test_set = (
        torchvision.datasets.CIFAR100(data_dir, train=is_train, download=data_config.download, transform=None)
        for is_train in (True, False)
    )
    return full_train_dataset, test_set, transforms

def get_noisy_dataloader(dataloader: torch.utils.data.DataLoader, noise_level: float, salt_and_pepper: bool = False) -> torch.utils.data.DataLoader:
    """
    Build a dataloader that serves corrupted copies of another dataloader's samples.

    Two corruption modes are available:
      * Gaussian (default): adds ``noise_level``-scaled standard-normal noise
        and clamps the result to [0, 1].
      * Salt-and-pepper: with a uniform random mask, elements whose mask value
        is below ``noise_level / 2`` are set to 0 and those above
        ``1 - noise_level / 2`` are set to 1.

    Args:
        dataloader: Original dataloader; its dataset (with all existing
            transforms) is wrapped.
        noise_level: Noise scale (Gaussian) or corruption level in [0, 1]
            (salt-and-pepper). Must be non-negative.
        salt_and_pepper: Selects the salt-and-pepper mode.

    Returns:
        A new dataloader over the noisy dataset that reuses the original
        batch size, worker count, ``drop_last`` setting and shuffling choice.
    """
    class NoiseTransform(torch.utils.data.Dataset):
        def __init__(self, dataset, noise_level: float, salt_and_pepper: bool = False):
            self.dataset = dataset
            assert 0 <= noise_level
            self.noise_level = noise_level
            self.salt_and_pepper = salt_and_pepper
            if salt_and_pepper:
                assert 0 <= noise_level <= 1
                return
            # Gaussian mode: perturb, then bring values back into [0, 1].
            self.transform = T.Compose([
                T.Lambda(lambda x: x + torch.randn_like(x) * noise_level),
                T.Lambda(lambda x: torch.clamp(x, 0, 1))
            ])

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            # The wrapped dataset has already applied its own transforms.
            image, label = self.dataset[idx]
            if self.salt_and_pepper:
                corrupted = image.clone()
                mask = torch.rand_like(image)
                corrupted[mask < noise_level/2] = 0
                corrupted[mask > 1 - noise_level/2] = 1
                return corrupted, label
            return self.transform(image), label

    noisy_dataset = NoiseTransform(dataloader.dataset, noise_level, salt_and_pepper=salt_and_pepper)

    # Shuffle only if the source loader was drawing samples randomly.
    was_shuffled = isinstance(dataloader.sampler, torch.utils.data.sampler.RandomSampler)
    keep_drop_last = getattr(dataloader, 'drop_last', False)
    return torch.utils.data.DataLoader(
        noisy_dataset,
        batch_size=dataloader.batch_size,
        shuffle=was_shuffled,
        num_workers=dataloader.num_workers,
        drop_last=keep_drop_last,
    )


# Custom dataset class compatible with PyTorch
class SyntheticBinaryDataset(Dataset):
    """In-memory dataset pairing row ``i`` of ``data`` with entry ``i`` of ``labels``."""

    def __init__(self, data: torch.Tensor, labels: torch.Tensor):
        self.data = data
        self.labels = labels

    def __len__(self):
        # Number of samples is the size of the first dimension of ``data``.
        num_samples = self.data.shape[0]
        return num_samples

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

# Generate fixed bit positions for a given class (deterministic using class_index as seed)
def fixed_positions_for_class(class_index: int, input_size: int, num_bits: int) -> np.ndarray:
    """Return ``num_bits`` distinct positions in ``range(input_size)`` for a class.

    The draw uses a private ``RandomState`` seeded with ``class_index``, so
    the same arguments always give the same positions.
    """
    class_rng = np.random.RandomState(seed=class_index)
    positions = class_rng.choice(input_size, size=num_bits, replace=False)
    return positions

# Generate fixed bit values for a given class (different seed from positions)
def fixed_pattern_for_class(class_index: int, num_bits: int) -> torch.Tensor:
    """Return ``num_bits`` random 0/1 values for a class as a ``uint8`` tensor.

    The values come from a private ``RandomState`` seeded with
    ``class_index + 10_000`` (a different seed than the one used for the
    positions), so the same arguments always give the same pattern.
    """
    seed = class_index + 10_000
    bits = np.random.RandomState(seed).randint(0, 2, size=num_bits)
    return torch.tensor(bits, dtype=torch.uint8)

# Main function to create both train and test datasets
def create_synthetic_binary_datasets(
    num_classes: int,
    samples_per_class_train: int,
    samples_per_class_test: int,
    input_size: int = 784,
    fixed_bits_per_class: int = 100,
    lower_bound_fixed: int = 100
) -> Tuple[Dataset, Dataset]:
    """Create train/test datasets of random bit vectors with class-specific fixed bits.

    Every sample is a uniformly random 0/1 vector of length ``input_size`` in
    which a class-dependent set of positions is overwritten with a
    class-dependent pattern. The number of fixed bits of each class is drawn
    with ``np.random.randint`` from
    ``[lower_bound_fixed, fixed_bits_per_class]`` while the training set is
    built and then reused for the test set.

    Returns:
        ``(train_dataset, test_dataset)`` with float inputs and integer labels.
    """
    assert fixed_bits_per_class <= input_size, "fixed_bits_per_class must be <= input_size"

    def generate_dataset(samples_per_class, num_fixed=None):
        """Build one split; returns the dataset and the per-class fixed-bit counts."""
        fixeds = [None] * num_classes if num_fixed is None else num_fixed
        rows, targets = [], []

        for class_label in range(num_classes):
            fixed = fixeds[class_label]
            if fixed is None:
                # First time this class is seen: draw its bit count and remember it.
                fixed = np.random.randint(lower_bound_fixed, fixed_bits_per_class + 1)
                fixeds[class_label] = fixed

            fixed_positions = fixed_positions_for_class(class_label, input_size, fixed)
            fixed_pattern = fixed_pattern_for_class(class_label, fixed)

            for _ in range(samples_per_class):
                # Random bits everywhere, then stamp the class signature on top.
                sample = torch.randint(0, 2, (input_size,), dtype=torch.uint8)
                sample[fixed_positions] = fixed_pattern
                rows.append(sample)
                targets.append(class_label)

        inputs = torch.stack(rows).float()  # float for model input
        return SyntheticBinaryDataset(inputs, torch.tensor(targets)), fixeds

    train_dataset, per_class_counts = generate_dataset(samples_per_class_train)
    test_dataset, _ = generate_dataset(samples_per_class_test, num_fixed=per_class_counts)
    return train_dataset, test_dataset
# Disabled variant of create_synthetic_binary_datasets, kept inside a bare
# module-level string so it is never executed. It takes an extra ``seed``
# argument and draws each class's fixed-bit count once, up front, from
# np.random.RandomState(seed) instead of the global numpy RNG.
"""
def create_synthetic_binary_datasets(
    num_classes: int,
    samples_per_class_train: int,
    samples_per_class_test: int,
    input_size: int = 784,
    fixed_bits_per_class: int = 100,
    lower_bound_fixed: int = 100,
    seed: int = 42
) -> Tuple[Dataset, Dataset]:
    assert fixed_bits_per_class <= input_size, "fixed_bits_per_class must be <= input_size"

    rng = np.random.RandomState(seed)

    # 🔁 Precompute class-wise fixed value (used consistently across train/test)
    class_configs = {}
    for class_label in range(num_classes):
        # Sample number of fixed bits *once* per class
        fixed = rng.randint(lower_bound_fixed, fixed_bits_per_class + 1)
        positions = fixed_positions_for_class(class_label, input_size, fixed)
        pattern = fixed_pattern_for_class(class_label, fixed)
        class_configs[class_label] = (positions, pattern)

    def generate_dataset(samples_per_class):
        data = []
        labels = []
        for class_label in range(num_classes):
            positions, pattern = class_configs[class_label]
            for _ in range(samples_per_class):
                sample = torch.randint(0, 2, (input_size,), dtype=torch.uint8)
                sample[positions] = pattern
                data.append(sample)
                labels.append(class_label)
        data = torch.stack(data).float()
        labels = torch.tensor(labels)
        return SyntheticBinaryDataset(data, labels)

    train_dataset = generate_dataset(samples_per_class_train)
    test_dataset = generate_dataset(samples_per_class_test)

    return train_dataset, test_dataset
"""

def create_classwise_pixel_distribution_datasets(
    hist_vector: list,
    num_classes: int,
    samples_per_class_train: int,
    samples_per_class_test: int,
    input_size: int = 784,
    seed: Optional[int] = None
) -> Tuple[Dataset, Dataset]:
    """Create train/test datasets of Bernoulli pixels with class-specific activation probabilities.

    ``hist_vector`` is normalised and read as a histogram over equally wide
    probability bins on [0, 1]. For every class the pixels are randomly
    assigned to bins in proportion to the histogram, and all pixels of a bin
    share one activation probability drawn uniformly from that bin's range.
    Samples are then drawn pixel-wise from the resulting Bernoulli
    distributions; train and test share the same class probabilities.

    Args:
        hist_vector: Relative share of pixels per probability bin.
        num_classes: Number of classes to generate.
        samples_per_class_train: Training samples per class.
        samples_per_class_test: Test samples per class.
        input_size: Number of pixels per sample.
        seed: When given, all random draws come from dedicated numpy and torch
            generators seeded with this value, making the result
            reproducible. When ``None`` the global numpy and torch RNGs are
            used.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    hist_vector = np.array(hist_vector, dtype=np.float32)
    hist_vector = hist_vector / hist_vector.sum()  # Normalize to sum to 1
    bin_edges = np.linspace(0, 1, len(hist_vector) + 1)  # e.g., [0.0, 0.1, ..., 1.0]

    # ``seed`` used to be accepted but ignored; honour it with private
    # generators so the global RNG state is left alone.
    if seed is None:
        np_rng = np.random
        torch_gen = None
    else:
        np_rng = np.random.RandomState(seed)
        torch_gen = torch.Generator().manual_seed(seed)

    # How many pixels fall into each bin; identical for every class. Pixels
    # lost to rounding down are handed to the leading bins.
    pixels_per_bin = (hist_vector * input_size).astype(int)
    gap = input_size - pixels_per_bin.sum()
    if gap > 0:
        pixels_per_bin[:gap] += 1

    def generate_class_pixel_probs() -> torch.Tensor:
        """Generate class-wise pixel activation probabilities.
        Returns:
            probs: Tensor of shape [num_classes, input_size]
        """
        probs = []
        for _ in range(num_classes):
            class_probs = torch.zeros(input_size)
            pixel_indices = np_rng.permutation(input_size)
            start = 0
            for bin_idx, count in enumerate(pixels_per_bin):
                if count == 0:
                    continue
                p = np_rng.uniform(bin_edges[bin_idx], bin_edges[bin_idx + 1])
                selected = pixel_indices[start:start + count]
                class_probs[selected] = p
                start += count
            probs.append(class_probs)
        return torch.stack(probs)  # [num_classes, input_size]

    def generate_dataset(samples_per_class, class_pixel_probs):
        data = []
        labels = []
        for class_id in range(num_classes):
            probs = class_pixel_probs[class_id]  # [input_size]
            samples = torch.bernoulli(probs.repeat(samples_per_class, 1), generator=torch_gen)
            data.append(samples)
            labels.extend([class_id] * samples_per_class)

        data = torch.cat(data, dim=0).float()  # [N, input_size]
        labels = torch.tensor(labels)
        return SyntheticBinaryDataset(data, labels)

    class_pixel_probs = generate_class_pixel_probs()

    train_dataset = generate_dataset(samples_per_class_train, class_pixel_probs)
    test_dataset = generate_dataset(samples_per_class_test, class_pixel_probs)

    return train_dataset, test_dataset

class ImageNet32Dataset(Dataset):
    """ImageNet32 samples read eagerly from pickled batch files.

    Each batch file must unpickle to a mapping with ``'data'`` (reshaped here
    to ``(-1, 3, 32, 32)``) and ``'labels'`` (shifted to start at 0).

    Attributes set in ``__init__``:
        data: list of ``(image array, label)`` pairs.
        labels: the labels alone, in the same order as ``data``.
        transform: optional callable applied to the PIL image in ``__getitem__``.
    """

    def __init__(self, batch_files, transform=None):
        self.data = []
        self.labels = []
        self.transform = transform

        for path in batch_files:
            # NOTE: unpickling runs arbitrary code; only load trusted batch files.
            with open(path, 'rb') as handle:
                batch = pickle.load(handle, encoding='latin1')
            images = np.array(batch['data']).reshape(-1, 3, 32, 32)
            targets = np.array(batch['labels']) - 1  # 0-based labels

            for i, image in enumerate(images):
                target = targets[i]
                self.data.append((image, target))
                self.labels.append(target)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        chw, label = self.data[idx]
        hwc = np.transpose(chw, (1, 2, 0))  # CHW to HWC
        img = Image.fromarray(hwc.astype(np.uint8))
        if self.transform:
            img = self.transform(img)
        return img, label

# Equivalent loader function
def get_imagenet32_dataset(data_config, transform_lists):
    """Load ImageNet32 plus one thresholding pipeline per entry of ``transform_lists``.

    Only the number of entries in ``transform_lists`` is used; their contents
    are ignored. Every pipeline applies ``ToTensor`` and then concatenates,
    along dim 0, one 0/1 copy per threshold
    ``(i + 1) / (image_thresholds + 1)``.

    Returns:
        ``(full_train_dataset, test_set, transforms)``. The training set is
        created without a transform; the test set already uses
        ``transforms[1]``.
    """
    data_dir = os.path.join(data_config.data_dir, 'imagenet32')
    image_thresholds = data_config.image_thresholds

    def threshold_planes(x):
        planes = [(x > (i + 1) / (image_thresholds + 1)).float() for i in range(image_thresholds)]
        return torch.cat(planes, dim=0)

    transforms = [
        T.Compose([T.ToTensor(), T.Lambda(threshold_planes)])
        for _ in transform_lists
    ]

    # Batch files on disk
    train_files = [os.path.join(data_dir, f'imagenet32_train/train_data_batch_{i}') for i in range(1, 11)]
    val_file = [os.path.join(data_dir, 'imagenet32_val/val_data')]

    full_train_dataset = ImageNet32Dataset(train_files, transform=None)
    test_set = ImageNet32Dataset(val_file, transform=transforms[1])
    return full_train_dataset, test_set, transforms