import torch
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
import os
import torch.nn.functional as F
import numpy as np
from PIL import Image
# from .options import args_parser


# Root directory passed to the torchvision dataset constructors in this module
# (they download into it). It is created at import time if it is missing.
dataset_path = '/directory/datasets'
os.makedirs(dataset_path, exist_ok=True)

# Defaults are hard-coded below; the commented-out parser lines show the
# command-line options they correspond to.
# args = args_parser()

# Worker-process count passed to every DataLoader built in this module.
num_workers = 4
# Default noise family for the *_Lp classes: 0 (Gaussian), 1 (L1), 2 (L2), 3 (L_inf).
lp=0
# Default noise magnitude for the *_Lp classes. In add_noise it multiplies
# standard-normal samples for Gaussian noise (a scale, not a variance) and
# multiplies the normalized / sign noise for the other families.
noise_epsilon=0.3
# parser.add_argument('--lp', type=int, default=0, help='0 (Gaussian) 1 (L1) 2 (L2) 3 (L_inf) noise')
# parser.add_argument('--noise_epsilon', type=float, default=0.1, help='Variance for Gaussian, and ball radius for other noises')

class CIFAR10:
    """Build train/validation/test DataLoaders for CIFAR-10 resized to 224x224.

    The official training set is split 80/20 into train and validation
    subsets; the official test set is used as-is.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global CPU (and, if available, CUDA) RNGs.
        data_randseed: Seed of the dedicated generator that drives the
            train/validation split and the training shuffle.
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24):
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_randseed = data_randseed
        self.g = self.set_seeds()

    def set_seeds(self):
        """Seed the global RNGs and return a generator seeded with ``data_randseed``."""
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.random_seed)
        g = torch.Generator()
        g.manual_seed(self.data_randseed)
        return g

    def _build_transform(self, train):
        """Return the preprocessing pipeline; random augmentation only if ``train``."""
        steps = [transforms.Resize((224, 224))]
        if train:
            steps += [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(224, padding=4),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
        ]
        return transforms.Compose(steps)

    def load_datasets(self):
        """Download CIFAR-10 if needed and build the three loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``. Only the
            training loader shuffles and applies random augmentation.
        """
        train_transform = self._build_transform(train=True)
        eval_transform = self._build_transform(train=False)

        # Two views of the same training images: the validation subset must
        # not go through the random flip/crop used for training.
        train_view = datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=train_transform)
        eval_view = datasets.CIFAR10(root=dataset_path, train=True, download=False, transform=eval_transform)
        cifar10_test = datasets.CIFAR10(root=dataset_path, train=False, download=True, transform=eval_transform)

        train_size = int(0.8 * len(train_view))
        val_size = len(train_view) - train_size
        # Drive the split from the data generator so it depends on
        # data_randseed rather than on the state of the global RNG.
        cifar10_train, val_part = random_split(train_view, [train_size, val_size], generator=self.g)
        cifar10_val = torch.utils.data.Subset(eval_view, val_part.indices)

        train_loader = DataLoader(cifar10_train, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(cifar10_val, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(cifar10_test, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

class CIFAR10_Lp(CIFAR10):
    """CIFAR-10 loaders whose images carry additive random noise.

    Noise is added to every split after conversion to a tensor and before
    normalization.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global RNGs.
        data_randseed: Seed of the generator used for the split and shuffle.
        lp: Noise family: 0 (Gaussian), 1 (L1), 2 (L2) or 3 (L-infinity).
        noise_epsilon: Noise magnitude (see ``add_noise``).
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24, lp=lp, noise_epsilon=noise_epsilon):
        super().__init__(batch_size, random_seed, data_randseed)
        self.lp = lp
        self.noise_epsilon = noise_epsilon

    def add_noise(self, x):
        """Return ``x`` plus freshly sampled noise of the configured family.

        Args:
            x: Floating-point tensor to perturb.

        Returns:
            A new tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``self.lp`` is not 0, 1, 2 or 3.
        """
        if self.lp == 0:  # Gaussian noise: standard-normal samples scaled by epsilon
            noise = torch.randn_like(x) * self.noise_epsilon
        elif self.lp == 1:  # L1 noise
            # NOTE(review): normalization runs along the last dimension only, so
            # each last-dimension slice (not the whole tensor) has L1 norm epsilon.
            noise = F.normalize(torch.rand_like(x) - 0.5, p=1, dim=-1) * self.noise_epsilon
        elif self.lp == 2:  # L2 noise
            # NOTE(review): as above, the L2 norm is epsilon per last-dimension slice.
            noise = F.normalize(torch.randn_like(x), p=2, dim=-1) * self.noise_epsilon
        elif self.lp == 3:  # L-infinity noise: every element is +/- epsilon
            noise = torch.sign(torch.randn_like(x)) * self.noise_epsilon
        else:
            raise ValueError(f"Invalid lp value: {self.lp}. Must be 0, 1, 2, or 3.")
        return x + noise

    def add_noise_transform(self, img):
        """Convert ``img`` to a tensor and add noise (no normalization happens here)."""
        img_tensor = transforms.ToTensor()(img)
        img_tensor = self.add_noise(img_tensor)
        return img_tensor

    def _build_noisy_transform(self, train):
        """Return the noisy pipeline; random flip/crop is included only if ``train``."""
        steps = [transforms.Resize((224, 224))]
        if train:
            steps += [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(224, padding=4),
            ]
        steps += [
            transforms.Lambda(self.add_noise_transform),  # to tensor, then add noise
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  # CIFAR-10 normalization
        ]
        return transforms.Compose(steps)

    def load_datasets(self):
        """Download CIFAR-10 if needed and build the three noisy loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``. Only the
            training loader shuffles and applies random flip/crop; all three
            apply noise.
        """
        train_transform = self._build_noisy_transform(train=True)
        eval_transform = self._build_noisy_transform(train=False)

        # Two views of the same training images: the validation subset must
        # not go through the random flip/crop used for training.
        train_view = datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=train_transform)
        eval_view = datasets.CIFAR10(root=dataset_path, train=True, download=False, transform=eval_transform)
        cifar10_test = datasets.CIFAR10(root=dataset_path, train=False, download=True, transform=eval_transform)

        train_size = int(0.8 * len(train_view))
        val_size = len(train_view) - train_size
        # Drive the split from the data generator so it depends on
        # data_randseed rather than on the state of the global RNG.
        cifar10_train, val_part = random_split(train_view, [train_size, val_size], generator=self.g)
        cifar10_val = torch.utils.data.Subset(eval_view, val_part.indices)

        train_loader = DataLoader(cifar10_train, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(cifar10_val, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(cifar10_test, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

class CIFAR10p1(Dataset):
    """CIFAR-10.1 (v6) as a map-style dataset that can also build its own loaders.

    The image and label arrays are read from a fixed directory when the
    object is constructed.
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24):
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_randseed = data_randseed
        self.g = self.set_seeds()

        # CIFAR-10.1 lives in a fixed location.
        self.data_path = '/directory/datasets/cifar10point1/'

        # Images and labels are stored as two separate .npy files.
        self.data = np.load(f"{self.data_path}cifar10.1_v6_data.npy")
        self.targets = np.load(f"{self.data_path}cifar10.1_v6_labels.npy")

        # Tensor conversion followed by the CIFAR-10 normalization constants.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        self.transform = transforms.Compose([to_tensor, normalize])

    def set_seeds(self):
        """Seed torch's global RNGs and hand back a generator seeded for data ordering."""
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.random_seed)
        data_generator = torch.Generator()
        data_generator.manual_seed(self.data_randseed)
        return data_generator

    def __len__(self):
        """Number of images held in ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; the label is a long tensor."""
        sample = self.data[idx]
        target = torch.tensor(self.targets[idx], dtype=torch.long)
        if not self.transform:
            return sample, target
        return self.transform(sample), target

    def load_datasets(self):
        """Split this dataset 70/15/15 and return ``(train, val, test)`` loaders."""
        total = len(self.data)
        n_train = int(0.7 * total)
        n_val = int(0.15 * total)
        n_test = total - n_train - n_val  # remainder goes to the test split

        parts = random_split(self, [n_train, n_val, n_test], generator=self.g)
        train_part, val_part, test_part = parts

        train_loader = DataLoader(
            train_part,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=num_workers,
            generator=self.g,
        )
        val_loader = DataLoader(
            val_part, batch_size=self.batch_size, shuffle=False, num_workers=num_workers
        )
        test_loader = DataLoader(
            test_part, batch_size=self.batch_size, shuffle=False, num_workers=num_workers
        )
        return train_loader, val_loader, test_loader

class CIFAR10p1_Lp(CIFAR10p1):
    """CIFAR-10.1 dataset whose images carry additive random noise.

    The noisy transform is installed at construction time, so samples are
    noisy whether they are read by indexing the dataset directly or through
    the loaders returned by ``load_datasets``.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global RNGs.
        data_randseed: Seed of the generator used for the split and shuffle.
        lp: Noise family: 0 (Gaussian), 1 (L1), 2 (L2) or 3 (L-infinity).
        noise_epsilon: Noise magnitude (see ``add_noise``).
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24, lp=lp, noise_epsilon=noise_epsilon):
        super().__init__(batch_size, random_seed, data_randseed)
        self.lp = lp
        self.noise_epsilon = noise_epsilon
        # Replace the clean transform set by the parent right away; otherwise
        # indexing the dataset before load_datasets() would return clean images.
        self.transform = self._build_noisy_transform()

    def add_noise(self, x):
        """Return ``x`` plus freshly sampled noise of the configured family.

        Args:
            x: Floating-point tensor to perturb.

        Returns:
            A new tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``self.lp`` is not 0, 1, 2 or 3.
        """
        if self.lp == 0:  # Gaussian noise: standard-normal samples scaled by epsilon
            noise = torch.randn_like(x) * self.noise_epsilon
        elif self.lp == 1:  # L1 noise
            # NOTE(review): normalization runs along the last dimension only, so
            # each last-dimension slice (not the whole tensor) has L1 norm epsilon.
            noise = F.normalize(torch.rand_like(x) - 0.5, p=1, dim=-1) * self.noise_epsilon
        elif self.lp == 2:  # L2 noise
            # NOTE(review): as above, the L2 norm is epsilon per last-dimension slice.
            noise = F.normalize(torch.randn_like(x), p=2, dim=-1) * self.noise_epsilon
        elif self.lp == 3:  # L-infinity noise: every element is +/- epsilon
            noise = torch.sign(torch.randn_like(x)) * self.noise_epsilon
        else:
            raise ValueError(f"Invalid lp value: {self.lp}. Must be 0, 1, 2, or 3.")
        return x + noise

    def add_noise_transform(self, img):
        """Convert ``img`` to a tensor and add noise (no normalization happens here)."""
        img_tensor = transforms.ToTensor()(img)
        img_tensor = self.add_noise(img_tensor)
        return img_tensor

    def _build_noisy_transform(self):
        """Return the pipeline: to tensor, add noise, then normalize."""
        return transforms.Compose([
            transforms.Lambda(self.add_noise_transform),  # to tensor, then add noise
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))  # CIFAR-10 normalization
        ])

    # __getitem__ is inherited: the parent implementation already applies
    # self.transform, which is all the previous override did.

    def load_datasets(self):
        """Split this dataset 70/15/15 and return ``(train, val, test)`` loaders.

        The noisy transform is (re)installed first, then the parent's
        splitting and loader construction is reused.
        """
        self.transform = self._build_noisy_transform()
        return super().load_datasets()


class CIFAR100:
    """Build train/validation/test DataLoaders for CIFAR-100 at native 32x32.

    The official training set is split 80/20 into train and validation
    subsets; the official test set is used as-is.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global CPU (and, if available, CUDA) RNGs.
        data_randseed: Seed of the dedicated generator that drives the
            train/validation split and the training shuffle.
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24):
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_randseed = data_randseed
        self.g = self.set_seeds()

    def set_seeds(self):
        """Seed the global RNGs and return a generator seeded with ``data_randseed``."""
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.random_seed)
        g = torch.Generator()
        g.manual_seed(self.data_randseed)
        return g

    def _build_transform(self, train):
        """Return the preprocessing pipeline; random augmentation only if ``train``."""
        steps = []
        if train:
            steps += [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4),
            ]
        steps += [
            transforms.ToTensor(),
            # NOTE(review): these constants are identical to the ones this file
            # uses for CIFAR-10; confirm whether CIFAR-100 statistics were intended.
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
        return transforms.Compose(steps)

    def load_datasets(self):
        """Download CIFAR-100 if needed and build the three loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``. Only the
            training loader shuffles and applies random augmentation.
        """
        train_transform = self._build_transform(train=True)
        eval_transform = self._build_transform(train=False)

        # Two views of the same training images: the validation subset must
        # not go through the random flip/crop used for training.
        train_view = datasets.CIFAR100(root=dataset_path, train=True, download=True, transform=train_transform)
        eval_view = datasets.CIFAR100(root=dataset_path, train=True, download=False, transform=eval_transform)
        cifar100_test = datasets.CIFAR100(root=dataset_path, train=False, download=True, transform=eval_transform)

        train_size = int(0.8 * len(train_view))
        val_size = len(train_view) - train_size
        # Drive the split from the data generator so it depends on
        # data_randseed rather than on the state of the global RNG.
        cifar100_train, val_part = random_split(train_view, [train_size, val_size], generator=self.g)
        cifar100_val = torch.utils.data.Subset(eval_view, val_part.indices)

        train_loader = DataLoader(cifar100_train, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(cifar100_val, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(cifar100_test, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

class CIFAR100_Lp(CIFAR100):
    """CIFAR-100 loaders whose images carry additive random noise.

    Noise is added to every split after conversion to a tensor and before
    normalization.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global RNGs.
        data_randseed: Seed of the generator used for the split and shuffle.
        lp: Noise family: 0 (Gaussian), 1 (L1), 2 (L2) or 3 (L-infinity).
        noise_epsilon: Noise magnitude (see ``add_noise``).
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24, lp=lp, noise_epsilon=noise_epsilon):
        super().__init__(batch_size, random_seed, data_randseed)
        self.lp = lp
        self.noise_epsilon = noise_epsilon

    def add_noise(self, x):
        """Return ``x`` plus freshly sampled noise of the configured family.

        Args:
            x: Floating-point tensor to perturb.

        Returns:
            A new tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``self.lp`` is not 0, 1, 2 or 3.
        """
        if self.lp == 0:  # Gaussian noise: standard-normal samples scaled by epsilon
            noise = torch.randn_like(x) * self.noise_epsilon
        elif self.lp == 1:  # L1 noise
            # NOTE(review): normalization runs along the last dimension only, so
            # each last-dimension slice (not the whole tensor) has L1 norm epsilon.
            noise = F.normalize(torch.rand_like(x) - 0.5, p=1, dim=-1) * self.noise_epsilon
        elif self.lp == 2:  # L2 noise
            # NOTE(review): as above, the L2 norm is epsilon per last-dimension slice.
            noise = F.normalize(torch.randn_like(x), p=2, dim=-1) * self.noise_epsilon
        elif self.lp == 3:  # L-infinity noise: every element is +/- epsilon
            noise = torch.sign(torch.randn_like(x)) * self.noise_epsilon
        else:
            raise ValueError(f"Invalid lp value: {self.lp}. Must be 0, 1, 2, or 3.")
        return x + noise

    def add_noise_transform(self, img):
        """Convert ``img`` to a tensor and add noise (no normalization happens here)."""
        img_tensor = transforms.ToTensor()(img)
        img_tensor = self.add_noise(img_tensor)
        return img_tensor

    def _build_noisy_transform(self, train):
        """Return the noisy pipeline; random flip/crop is included only if ``train``."""
        steps = []
        if train:
            steps += [
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, padding=4),
            ]
        steps += [
            transforms.Lambda(self.add_noise_transform),  # to tensor, then add noise
            # NOTE(review): these constants are identical to the ones this file
            # uses for CIFAR-10; confirm whether CIFAR-100 statistics were intended.
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
        return transforms.Compose(steps)

    def load_datasets(self):
        """Download CIFAR-100 if needed and build the three noisy loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``. Only the
            training loader shuffles and applies random flip/crop; all three
            apply noise.
        """
        train_transform = self._build_noisy_transform(train=True)
        eval_transform = self._build_noisy_transform(train=False)

        # Two views of the same training images: the validation subset must
        # not go through the random flip/crop used for training.
        train_view = datasets.CIFAR100(root=dataset_path, train=True, download=True, transform=train_transform)
        eval_view = datasets.CIFAR100(root=dataset_path, train=True, download=False, transform=eval_transform)
        cifar100_test = datasets.CIFAR100(root=dataset_path, train=False, download=True, transform=eval_transform)

        train_size = int(0.8 * len(train_view))
        val_size = len(train_view) - train_size
        # Drive the split from the data generator so it depends on
        # data_randseed rather than on the state of the global RNG.
        cifar100_train, val_part = random_split(train_view, [train_size, val_size], generator=self.g)
        cifar100_val = torch.utils.data.Subset(eval_view, val_part.indices)

        train_loader = DataLoader(cifar100_train, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(cifar100_val, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(cifar100_test, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

class MNIST:
    """Build train/validation/test DataLoaders for MNIST as 3-channel 224x224 images.

    The official training set is split 80/20 into train and validation
    subsets; the official test set is used as-is.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global CPU (and, if available, CUDA) RNGs.
        data_randseed: Seed of the dedicated generator that drives the
            train/validation split and the training shuffle.
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24):
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_randseed = data_randseed
        self.g = self.set_seeds()

    def set_seeds(self):
        """Seed the global RNGs and return a generator seeded with ``data_randseed``."""
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.random_seed)
        g = torch.Generator()
        g.manual_seed(self.data_randseed)
        return g

    def load_datasets(self):
        """Download MNIST if needed and build the three loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``; only the
            training loader shuffles.
        """
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # Convert 1 channel to 3 channels
            transforms.Resize((224, 224)),  # Resize to the input size expected by ResNet-50
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize as expected by ResNet-50
        ])
        mnist_trainval = datasets.MNIST(root=dataset_path, train=True, download=True, transform=transform)
        mnist_test = datasets.MNIST(root=dataset_path, train=False, download=True, transform=transform)

        train_size = int(0.8 * len(mnist_trainval))
        val_size = len(mnist_trainval) - train_size
        # Drive the split from the data generator so it depends on
        # data_randseed rather than on the state of the global RNG.
        mnist_train, mnist_val = random_split(mnist_trainval, [train_size, val_size], generator=self.g)

        train_loader = DataLoader(mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(mnist_val, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(mnist_test, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

class MNIST_Lp(MNIST):
    """MNIST loaders whose images carry additive random noise.

    Noise is added to every split after conversion to a tensor and before
    normalization.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global RNGs.
        data_randseed: Seed of the generator used for the split and shuffle.
        lp: Noise family: 0 (Gaussian), 1 (L1), 2 (L2) or 3 (L-infinity).
        noise_epsilon: Noise magnitude (see ``add_noise``).
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24, lp=lp, noise_epsilon=noise_epsilon):
        super().__init__(batch_size, random_seed, data_randseed)
        self.lp = lp
        self.noise_epsilon = noise_epsilon

    def add_noise(self, x):
        """Return ``x`` plus freshly sampled noise of the configured family.

        Args:
            x: Floating-point tensor to perturb.

        Returns:
            A new tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``self.lp`` is not 0, 1, 2 or 3.
        """
        if self.lp == 0:  # Gaussian noise: standard-normal samples scaled by epsilon
            noise = torch.randn_like(x) * self.noise_epsilon
        elif self.lp == 1:  # L1 noise
            # NOTE(review): normalization runs along the last dimension only, so
            # each last-dimension slice (not the whole tensor) has L1 norm epsilon.
            noise = F.normalize(torch.rand_like(x) - 0.5, p=1, dim=-1) * self.noise_epsilon
        elif self.lp == 2:  # L2 noise
            # NOTE(review): as above, the L2 norm is epsilon per last-dimension slice.
            noise = F.normalize(torch.randn_like(x), p=2, dim=-1) * self.noise_epsilon
        elif self.lp == 3:  # L-infinity noise: every element is +/- epsilon
            noise = torch.sign(torch.randn_like(x)) * self.noise_epsilon
        else:
            raise ValueError(f"Invalid lp value: {self.lp}. Must be 0, 1, 2, or 3.")
        return x + noise

    def add_noise_transform(self, img):
        """Convert ``img`` to a tensor and add noise (no normalization happens here)."""
        img_tensor = transforms.ToTensor()(img)
        img_tensor = self.add_noise(img_tensor)
        return img_tensor

    def load_datasets(self):
        """Download MNIST if needed and build the three noisy loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``; only the
            training loader shuffles.
        """
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # Convert 1 channel to 3 channels
            transforms.Resize((224, 224)),  # Resize to the input size expected by ResNet-50
            transforms.Lambda(self.add_noise_transform),  # to tensor, then add noise
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize as expected by ResNet-50
        ])

        mnist_trainval = datasets.MNIST(root=dataset_path, train=True, download=True, transform=transform)
        mnist_test = datasets.MNIST(root=dataset_path, train=False, download=True, transform=transform)

        train_size = int(0.8 * len(mnist_trainval))
        val_size = len(mnist_trainval) - train_size
        # Drive the split from the data generator so it depends on
        # data_randseed rather than on the state of the global RNG.
        mnist_train, mnist_val = random_split(mnist_trainval, [train_size, val_size], generator=self.g)

        train_loader = DataLoader(mnist_train, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(mnist_val, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(mnist_test, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader


class FMNIST:
    """Build train/validation/test DataLoaders for Fashion-MNIST as 3-channel 224x224 images.

    The official training set is split 80/20 into train and validation
    subsets; the official test set is used as-is.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global CPU (and, if available, CUDA) RNGs.
        data_randseed: Seed of the dedicated generator that drives the
            train/validation split and the training shuffle.
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24):
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_randseed = data_randseed
        self.g = self.set_seeds()

    def set_seeds(self):
        """Seed the global RNGs and return a generator seeded with ``data_randseed``."""
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.random_seed)
        g = torch.Generator()
        g.manual_seed(self.data_randseed)
        return g

    def load_datasets(self):
        """Download Fashion-MNIST if needed and build the three loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``; only the
            training loader shuffles.
        """
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # Convert 1 channel to 3 channels
            transforms.Resize((224, 224)),  # Resize to the input size expected by ResNet-50
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize as expected by ResNet-50
        ])

        fashionmnist_trainval = datasets.FashionMNIST(root=dataset_path, train=True, download=True, transform=transform)
        fashionmnist_test = datasets.FashionMNIST(root=dataset_path, train=False, download=True, transform=transform)

        train_size = int(0.8 * len(fashionmnist_trainval))
        val_size = len(fashionmnist_trainval) - train_size
        # Drive the split from the data generator so it depends on
        # data_randseed rather than on the state of the global RNG.
        fashionmnist_train, fashionmnist_val = random_split(fashionmnist_trainval, [train_size, val_size], generator=self.g)

        train_loader = DataLoader(fashionmnist_train, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(fashionmnist_val, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(fashionmnist_test, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

class FMNIST_Lp(FMNIST):
    """Fashion-MNIST loaders whose images carry additive random noise.

    Noise is added to every split after conversion to a tensor and before
    normalization.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global RNGs.
        data_randseed: Seed of the generator used for the split and shuffle.
        lp: Noise family: 0 (Gaussian), 1 (L1), 2 (L2) or 3 (L-infinity).
        noise_epsilon: Noise magnitude (see ``add_noise``).
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24, lp=lp, noise_epsilon=noise_epsilon):
        super().__init__(batch_size, random_seed, data_randseed)
        self.lp = lp
        self.noise_epsilon = noise_epsilon

    def add_noise(self, x):
        """Return ``x`` plus freshly sampled noise of the configured family.

        Args:
            x: Floating-point tensor to perturb.

        Returns:
            A new tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``self.lp`` is not 0, 1, 2 or 3.
        """
        if self.lp == 0:  # Gaussian noise: standard-normal samples scaled by epsilon
            noise = torch.randn_like(x) * self.noise_epsilon
        elif self.lp == 1:  # L1 noise
            # NOTE(review): normalization runs along the last dimension only, so
            # each last-dimension slice (not the whole tensor) has L1 norm epsilon.
            noise = F.normalize(torch.rand_like(x) - 0.5, p=1, dim=-1) * self.noise_epsilon
        elif self.lp == 2:  # L2 noise
            # NOTE(review): as above, the L2 norm is epsilon per last-dimension slice.
            noise = F.normalize(torch.randn_like(x), p=2, dim=-1) * self.noise_epsilon
        elif self.lp == 3:  # L-infinity noise: every element is +/- epsilon
            noise = torch.sign(torch.randn_like(x)) * self.noise_epsilon
        else:
            raise ValueError(f"Invalid lp value: {self.lp}. Must be 0, 1, 2, or 3.")
        return x + noise

    def add_noise_transform(self, img):
        """Convert ``img`` to a tensor and add noise (no normalization happens here)."""
        img_tensor = transforms.ToTensor()(img)
        img_tensor = self.add_noise(img_tensor)
        return img_tensor

    def load_datasets(self):
        """Download Fashion-MNIST if needed and build the three noisy loaders.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``; only the
            training loader shuffles.
        """
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # Convert 1 channel to 3 channels
            transforms.Resize((224, 224)),  # Resize to the input size expected by ResNet-50
            transforms.Lambda(self.add_noise_transform),  # to tensor, then add noise
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize as expected by ResNet-50
        ])

        fashionmnist_trainval = datasets.FashionMNIST(root=dataset_path, train=True, download=True, transform=transform)
        fashionmnist_test = datasets.FashionMNIST(root=dataset_path, train=False, download=True, transform=transform)

        train_size = int(0.8 * len(fashionmnist_trainval))
        val_size = len(fashionmnist_trainval) - train_size
        # Drive the split from the data generator so it depends on
        # data_randseed rather than on the state of the global RNG.
        fashionmnist_train, fashionmnist_val = random_split(fashionmnist_trainval, [train_size, val_size], generator=self.g)

        train_loader = DataLoader(fashionmnist_train, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(fashionmnist_val, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(fashionmnist_test, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over parallel, indexable collections of features and labels.

    Args:
        features: Indexable collection of samples; its length is the dataset length.
        labels: Indexable collection of labels, read at the same index as the sample.
        transform: Optional callable applied to each sample before it is returned.
    """

    def __init__(self, features, labels, transform=None):
        self.features, self.labels, self.transform = features, labels, transform

    def __len__(self):
        """Number of samples, taken from ``features``."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for ``idx``; the label is a long tensor."""
        sample = self.features[idx]
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        if not self.transform:
            return sample, target
        return self.transform(sample), target


class GLD23K:
    """Build train/validation/test DataLoaders from GLD-23K NumPy archives.

    Every file in the train and test directories is loaded and concatenated
    in memory. The first 80% of the concatenated training samples form the
    training set and the remaining 20% the validation set.

    Args:
        batch_size: Batch size used by all three loaders.
        random_seed: Seed for torch's global CPU (and, if available, CUDA) RNGs.
        data_randseed: Seed of the dedicated generator that drives the
            training shuffle.
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24):
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_randseed = data_randseed
        self.train_directory = "/net/scratch/user/gld23k_build/train_np/"
        self.test_directory = "/net/scratch/user/gld23k_build/test_np/"
        self.g = self.set_seeds()

    def set_seeds(self):
        """Seed the global RNGs and return a generator seeded with ``data_randseed``."""
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.random_seed)
        g = torch.Generator()
        g.manual_seed(self.data_randseed)
        return g

    def load_datasets(self):
        """Load all archives and build the three loaders.

        The loaders use the module-level ``num_workers`` setting.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``; only the
            training loader shuffles.
        """
        # Load and concatenate all training data
        train_features, train_labels = self.load_data(self.train_directory)

        # Positional 80/20 split: the leading samples train, the tail validates.
        train_size = int(0.8 * len(train_features))

        train_dataset = CustomDataset(train_features[:train_size], train_labels[:train_size])
        val_dataset = CustomDataset(train_features[train_size:], train_labels[train_size:])

        # Load test data
        test_features, test_labels = self.load_data(self.test_directory)
        test_dataset = CustomDataset(test_features, test_labels)

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=num_workers, generator=self.g)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

        return train_loader, val_loader, test_loader

    def load_data(self, directory):
        """Load every archive in ``directory`` and concatenate the contents.

        Each file must provide an ``'x'`` array that is moved from
        channels-last to channels-first layout, and a ``'y'`` array whose
        axis 1 has length one.

        Args:
            directory: Path of the directory holding the archives.

        Returns:
            Tuple ``(features, labels)`` of NumPy arrays; features are float32.

        Raises:
            ValueError: If ``directory`` contains no files.
        """
        # os.listdir returns entries in arbitrary order; sort so the sample
        # order, and therefore the positional train/validation split, is
        # the same on every run and machine.
        file_names = sorted(os.listdir(directory))
        if not file_names:
            raise ValueError(f"No data files found in {directory!r}.")

        features_list = []
        labels_list = []
        for file_name in file_names:
            archive = np.load(os.path.join(directory, file_name))
            try:
                features_list.append(archive['x'].astype(np.float32).transpose(0, 3, 1, 2))
                labels_list.append(archive['y'].squeeze(1))
            finally:
                # .npz archives keep their file handle open until closed;
                # other np.load results have no close() method.
                close = getattr(archive, 'close', None)
                if close is not None:
                    close()

        # Concatenate all data into a single array
        features_array = np.concatenate(features_list, axis=0)
        labels_array = np.concatenate(labels_list, axis=0)
        return features_array, labels_array

class GLD23K_Lp(GLD23K):
    """GLD-23K loaders whose samples carry additive random noise.

    ``lp`` selects the noise family (0 Gaussian, 1 L1, 2 L2, 3 L-infinity)
    and ``noise_epsilon`` its magnitude; the remaining arguments are passed
    to the parent constructor.
    """

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24, lp=lp, noise_epsilon=noise_epsilon):
        super().__init__(batch_size, random_seed, data_randseed)
        self.lp = lp
        self.noise_epsilon = noise_epsilon

    def add_noise(self, x):
        """Return ``x`` plus freshly sampled noise; raise ValueError for an unknown ``lp``."""
        # Gaussian: standard-normal samples scaled by epsilon.
        if self.lp == 0:
            return x + torch.randn_like(x) * self.noise_epsilon
        # L1: centred uniform samples, L1-normalized along the last dimension.
        if self.lp == 1:
            return x + F.normalize(torch.rand_like(x) - 0.5, p=1, dim=-1) * self.noise_epsilon
        # L2: normal samples, L2-normalized along the last dimension.
        if self.lp == 2:
            return x + F.normalize(torch.randn_like(x), p=2, dim=-1) * self.noise_epsilon
        # L-infinity: the sign of normal samples, scaled by epsilon.
        if self.lp == 3:
            return x + torch.sign(torch.randn_like(x)) * self.noise_epsilon
        raise ValueError(f"Invalid lp value: {self.lp}. Must be 0, 1, 2, or 3.")

    def add_noise_transform(self, img):
        """Copy ``img`` into a float32 tensor and return it with noise added."""
        return self.add_noise(torch.tensor(img, dtype=torch.float32))

    def load_datasets(self):
        """Load all archives and return noisy ``(train, val, test)`` loaders."""
        # Every split gets the same per-sample noise transform.
        noisy = transforms.Compose([
            transforms.Lambda(lambda sample: self.add_noise_transform(sample))
        ])

        # Training archives: the leading 80% train, the remaining 20% validate.
        features, labels = self.load_data(self.train_directory)
        cut = int(0.8 * len(features))
        train_dataset = CustomDataset(features[:cut], labels[:cut], transform=noisy)
        val_dataset = CustomDataset(features[cut:], labels[cut:], transform=noisy)

        # Test archives are used whole.
        test_features, test_labels = self.load_data(self.test_directory)
        test_dataset = CustomDataset(test_features, test_labels, transform=noisy)

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=num_workers,
            generator=self.g,
        )
        val_loader = DataLoader(
            val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers
        )
        test_loader = DataLoader(
            test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers
        )
        return train_loader, val_loader, test_loader