import torch
from torch.utils.data import DataLoader, random_split, Dataset, Subset
from torchvision import datasets, transforms
import os
import torch.nn.functional as F
import numpy as np
from PIL import Image
# from .options import args_parser

# Root directory holding the CIFAR-10 files. It is created eagerly, at import
# time, so the torchvision loaders below always find an existing folder.
dataset_path = '/directory/datasets'
os.makedirs(name=dataset_path, exist_ok=True)

# Number of worker processes given to every DataLoader built in this module.
num_workers = 4

class CIFAR10_oneclass:
    def __init__(self, batch_size=32, random_seed=42, data_randseed=24):
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_randseed = data_randseed
        self.g = self.set_seeds()
    
    def set_seeds(self):
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.random_seed)
        g = torch.Generator()
        g.manual_seed(self.data_randseed)
        return g
    
    def load_datasets(self):
        # Define transformations with resizing to 224x224
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),                      # Resize to 224x224
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), 
                                 (0.2023, 0.1994, 0.2010))    # CIFAR-10 normalization
        ])
        
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),                      # Resize to 224x224
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), 
                                 (0.2023, 0.1994, 0.2010))    # CIFAR-10 normalization
        ])
        
        # Load the datasets with the specified transformations
        cifar10_trainval = datasets.CIFAR10(
            root=dataset_path, train=True, download=True, transform=transform_train
        )
        cifar10_test = datasets.CIFAR10(
            root=dataset_path, train=False, download=True, transform=transform_test
        )
        
        # Split into training and validation sets (80-20 split)
        train_size = int(0.8 * len(cifar10_trainval))
        val_size = len(cifar10_trainval) - train_size
        cifar10_train, cifar10_val = random_split(
            cifar10_trainval, [train_size, val_size], generator=self.g
        )
        
        # Organize DataLoaders in a dictionary
        dataloaders = {i: {'train': None, 'val': None, 'test': None} for i in range(10)}
        
        # Function to get indices for a specific class in a dataset
        # def get_class_indices(dataset, class_label):
        #     targets = np.array(dataset.dataset.targets if isinstance(dataset.dataset, datasets.CIFAR10) else dataset.dataset.targets)
        #     if isinstance(dataset, Subset):
        #         indices = dataset.indices
        #         targets = targets[indices]
        #     return np.where(targets == class_label)[0]
        def get_class_indices(dataset, class_label):
            if isinstance(dataset, Subset):
                # If dataset is a Subset, access the underlying dataset's targets
                targets = np.array(dataset.dataset.targets)
                # Use the subset's indices to filter the targets
                targets = targets[dataset.indices]
            else:
                # If dataset is a full dataset (e.g., datasets.CIFAR10), access targets directly
                targets = np.array(dataset.targets)
            
            # Return the indices where targets match the class_label
            return np.where(targets == class_label)[0]
        
        for class_label in range(10):
            # Get indices for the current class in training, validation, and test sets
            train_indices = get_class_indices(cifar10_train, class_label)
            val_indices = get_class_indices(cifar10_val, class_label)
            test_indices = get_class_indices(cifar10_test, class_label)
            
            # Create subset datasets for the current class
            train_subset = Subset(cifar10_train, train_indices)
            val_subset = Subset(cifar10_val, val_indices)
            test_subset = Subset(cifar10_test, test_indices)
            
            # Create DataLoaders for the current class
            dataloaders[class_label]['train'] = DataLoader(
                train_subset, batch_size=self.batch_size, shuffle=True, 
                num_workers=num_workers, generator=self.g
            )
            dataloaders[class_label]['val'] = DataLoader(
                val_subset, batch_size=self.batch_size, shuffle=False, 
                num_workers=num_workers
            )
            dataloaders[class_label]['test'] = DataLoader(
                test_subset, batch_size=self.batch_size, shuffle=False, 
                num_workers=num_workers
            )
        
        return dataloaders

class CIFAR10_nineclass:
    """Build CIFAR-10 DataLoaders that each leave out exactly one class."""

    def __init__(self, batch_size=32, random_seed=42, data_randseed=24):
        """
        Initializes the CIFAR10_nineclass object with batch size and random seeds.

        Args:
            batch_size (int): Number of samples per batch.
            random_seed (int): Seed for torch's global RNGs (CPU and, if available, CUDA).
            data_randseed (int): Seed for data splitting reproducibility.
        """
        self.batch_size = batch_size
        self.random_seed = random_seed
        self.data_randseed = data_randseed
        self.g = self.set_seeds()

    def set_seeds(self):
        """
        Sets the random seeds for reproducibility.

        Returns:
            torch.Generator: A generator with the specified data_randseed.
        """
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.random_seed)
        g = torch.Generator()
        g.manual_seed(self.data_randseed)
        return g

    @staticmethod
    def _split_excluding(targets, exclude_class, seed, train_frac=0.8):
        """Shuffle the indices whose label is not ``exclude_class`` and split them.

        The permutation is identical to ``np.random.seed(seed)`` followed by
        ``np.random.shuffle(...)``. It is drawn from a private ``RandomState``,
        so the global NumPy RNG is left untouched.

        Args:
            targets: Labels of the full dataset, indexed by dataset index.
            exclude_class (int): Label to drop.
            seed (int): Seed for the shuffle.
            train_frac (float): Fraction of the kept indices assigned to train.

        Returns:
            tuple[list[int], list[int]]: ``(train_indices, val_indices)``.
        """
        kept = np.flatnonzero(np.asarray(targets) != exclude_class)
        np.random.RandomState(seed).shuffle(kept)
        train_size = int(train_frac * len(kept))
        return kept[:train_size].tolist(), kept[train_size:].tolist()

    def load_datasets(self):
        """
        Loads the CIFAR-10 datasets, applies transformations, splits into train and validation,
        and creates DataLoaders excluding one class at a time.

        Training loaders use random horizontal flips. Validation and test
        loaders use the deterministic test transform.

        Returns:
            dict: A dictionary where each key is a class label (0-9) and each value is another
                dictionary containing 'train', 'val', and 'test' DataLoaders excluding that class.
        """
        # CIFAR-10 normalization, shared by both pipelines.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ])

        cifar10_trainval = datasets.CIFAR10(
            root=dataset_path, train=True, download=True, transform=transform_train
        )
        # The same training images, viewed through the deterministic test
        # transform. The validation split is drawn from this view so it is not
        # randomly flipped.
        cifar10_trainval_eval = datasets.CIFAR10(
            root=dataset_path, train=True, download=True, transform=transform_test
        )
        cifar10_test = datasets.CIFAR10(
            root=dataset_path, train=False, download=True, transform=transform_test
        )

        # Convert labels to arrays once instead of on every loop iteration.
        trainval_targets = np.asarray(cifar10_trainval.targets)
        test_targets = np.asarray(cifar10_test.targets)

        dataloaders = {}
        for exclude_class in range(10):
            # 80/20 train/val split of the training images without this class.
            train_indices, val_indices = self._split_excluding(
                trainval_targets, exclude_class, self.data_randseed
            )
            test_indices = np.flatnonzero(test_targets != exclude_class).tolist()

            dataloaders[exclude_class] = {
                'train': DataLoader(
                    Subset(cifar10_trainval, train_indices),
                    batch_size=self.batch_size, shuffle=True,
                    num_workers=num_workers
                ),
                'val': DataLoader(
                    Subset(cifar10_trainval_eval, val_indices),
                    batch_size=self.batch_size, shuffle=False,
                    num_workers=num_workers
                ),
                'test': DataLoader(
                    Subset(cifar10_test, test_indices),
                    batch_size=self.batch_size, shuffle=False,
                    num_workers=num_workers
                ),
            }

        return dataloaders
