import torch
from torchvision import datasets, transforms
import numpy as np
from torch.utils.data import Dataset, Subset
import random

class CIFAR10Continual:
    """Split CIFAR-10 into a sequence of class-incremental tasks.

    At construction the CIFAR-10 classes are shuffled once (after seeding the
    global RNGs). Task ``t`` with ``num_classes`` classes per task uses the
    slice ``all_classes[t * num_classes:(t + 1) * num_classes]``. Every call to
    :meth:`get_task_dataset` draws a fresh per-class sample (without
    replacement) from the global NumPy RNG.
    """

    # Per-channel normalization statistics used by the 'norm' option.
    _MEAN = (0.4914, 0.4822, 0.4465)
    _STD = (0.247, 0.243, 0.261)

    def __init__(self, root='./data', samples_per_class=500, test_samples_per_class=100, seed=42, tranform=None):
        """
        Args:
            root: Directory where CIFAR-10 is stored / downloaded to.
            samples_per_class: Training images sampled per class for each task.
            test_samples_per_class: Test images sampled per class for each task.
            seed: Seed applied to the global ``torch``, ``numpy`` and ``random``
                RNGs (a process-wide side effect).
            tranform: (sic - spelling kept for backward compatibility) an
                iterable of transform names, or a single name, drawn from
                {'flip', 'crop', 'norm'}. 'flip' and 'crop' augment training
                images only; 'norm' normalizes training, test and non-augmented
                training images. ``None`` means no extra transforms.

        Raises:
            ValueError: If an unsupported transform name is given.
        """
        # Set random seeds for reproducibility.
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # The default None used to crash when iterated, and a bare string was
        # iterated character by character; normalize both to a sequence.
        if tranform is None:
            tranform = ()
        elif isinstance(tranform, str):
            tranform = (tranform,)

        train_transforms = [transforms.ToTensor()]
        eval_transforms = [transforms.ToTensor()]
        for transform in tranform:
            if transform == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transform == 'crop':
                train_transforms.append(transforms.RandomCrop(32, padding=4))
            elif transform == 'norm':
                train_transforms.append(transforms.Normalize(self._MEAN, self._STD))
                eval_transforms.append(transforms.Normalize(self._MEAN, self._STD))
            else:
                raise ValueError(f"Transformation type {transform} is not supported")

        print(train_transforms)

        # The non-augmented view of the training set uses the same deterministic
        # pipeline as the test set, so it is normalized exactly when 'norm' was
        # requested and matches the scale of the training/test inputs.
        not_aug_train_transforms = transforms.Compose(list(eval_transforms))
        train_transforms = transforms.Compose(train_transforms)
        test_transforms = transforms.Compose(eval_transforms)

        # Load CIFAR-10
        self.train_dataset = datasets.CIFAR10(root=root, train=True,
                                              transform=train_transforms, download=True)

        self.not_aug_train_dataset = datasets.CIFAR10(root=root, train=True,
                                                      transform=not_aug_train_transforms, download=True)

        self.test_dataset = datasets.CIFAR10(root=root, train=False,
                                             transform=test_transforms, download=True)

        self.samples_per_class = samples_per_class
        self.test_samples_per_class = test_samples_per_class

        # Class label -> list of dataset indices with that label.
        self.train_indices_by_class = self._organize_by_class(self.train_dataset)
        self.test_indices_by_class = self._organize_by_class(self.test_dataset)

        # Shuffle all classes once; tasks are consecutive slices of this order.
        self.all_classes = list(range(len(self.train_dataset.classes)))
        random.shuffle(self.all_classes)

    def _organize_by_class(self, dataset):
        """Return ``{label: [indices]}`` for every class of *dataset*.

        Reads labels from ``dataset.targets`` when present, which avoids decoding
        and transforming every image (and so does not consume the torch RNG
        through random augmentations). Otherwise it falls back to iterating
        ``(sample, label)`` pairs. The class count comes from
        ``dataset.classes`` when present, else 10.
        """
        classes = getattr(dataset, 'classes', None)
        num_classes = len(classes) if classes else 10
        indices_by_class = {i: [] for i in range(num_classes)}

        targets = getattr(dataset, 'targets', None)
        if targets is None:
            # Slow path: every sample is loaded (and transformed) to get its label.
            targets = (label for _, label in dataset)

        for idx, label in enumerate(targets):
            indices_by_class[int(label)].append(idx)
        return indices_by_class

    class TaskSubset(Dataset):
        """Index-restricted view of a dataset with optional label remapping.

        Labels are mapped through ``class_mapping`` (original label -> position
        of that class within the task, i.e. ``0..num_classes-1``) unless
        ``preserve_labels`` is True, in which case original labels are returned.
        """
        def __init__(self, dataset, not_aug_train_dataset, indices, class_mapping, preserve_labels=False):
            self.dataset = dataset
            self.indices = indices
            # Retained for API compatibility; __getitem__ does not read it.
            self.not_aug_train_dataset = not_aug_train_dataset
            self.class_mapping = class_mapping
            self.preserve_labels = preserve_labels

        def __getitem__(self, idx):
            """Return ``(image, label)`` for the ``idx``-th selected sample."""
            # Only the primary dataset is read; loading the non-augmented image
            # here would double the decoding cost for a value that is not returned.
            image, label = self.dataset[self.indices[idx]]
            if not self.preserve_labels:
                label = self.class_mapping[label]
            return image, label

        def __len__(self):
            return len(self.indices)

    def get_task_dataset(self, task_id, num_classes=2, preserve_labels=False):
        """
        Build the train/test datasets for one task.

        The task's classes are
        ``self.all_classes[task_id * num_classes:(task_id + 1) * num_classes]``.
        The last task may have fewer classes when the class count is not a
        multiple of ``num_classes``. For each class, ``samples_per_class``
        training and ``test_samples_per_class`` test indices are drawn without
        replacement from the global NumPy RNG, so repeated calls for the same
        task return different samples.

        Args:
            task_id: Zero-based index of the task.
            num_classes: Number of classes per task.
            preserve_labels: If True, keep the original CIFAR-10 labels (0-9).
                If False, remap labels to 0-(num_classes-1) in task-class order.

        Returns:
            ``(train_subset, test_subset, not_aug_train_subset, task_classes)``,
            where ``not_aug_train_subset`` holds the same training indices as
            ``train_subset`` but read without augmentation.

        Raises:
            ValueError: If ``task_id`` is negative, ``num_classes`` is not
                positive, or the task has no classes (``task_id`` past the last
                task). ``np.random.choice`` also raises it if a class has fewer
                samples than requested.
        """
        if task_id < 0:
            raise ValueError(f"task_id must be non-negative, got {task_id}")
        if num_classes <= 0:
            raise ValueError(f"num_classes must be positive, got {num_classes}")

        start_idx = task_id * num_classes
        task_classes = self.all_classes[start_idx:start_idx + num_classes]
        if not task_classes:
            raise ValueError(
                f"Task {task_id} has no classes: only {len(self.all_classes)} classes "
                f"are available with {num_classes} classes per task"
            )

        # Map original class labels to 0-(num_classes-1) in task order.
        class_mapping = {original: idx for idx, original in enumerate(task_classes)}

        train_indices = []
        test_indices = []
        # Train and test draws are interleaved per class; keeping this order
        # keeps the sampled indices stable for a given NumPy seed.
        for class_idx in task_classes:
            train_indices.extend(np.random.choice(
                self.train_indices_by_class[class_idx],
                size=self.samples_per_class,
                replace=False,
            ).tolist())
            test_indices.extend(np.random.choice(
                self.test_indices_by_class[class_idx],
                size=self.test_samples_per_class,
                replace=False,
            ).tolist())

        train_subset = self.TaskSubset(self.train_dataset, self.not_aug_train_dataset, train_indices,
                                       class_mapping, preserve_labels)
        not_aug_train_subset = self.TaskSubset(self.not_aug_train_dataset, self.not_aug_train_dataset, train_indices,
                                               class_mapping, preserve_labels)
        test_subset = self.TaskSubset(self.test_dataset, self.test_dataset, test_indices,
                                      class_mapping, preserve_labels)

        return train_subset, test_subset, not_aug_train_subset, task_classes

    def get_class_names(self, class_indices):
        """Return the class names for the given class indices, in order."""
        names = self.train_dataset.classes
        return [names[idx] for idx in class_indices]