import torch
from torchvision import datasets, transforms
import numpy as np
from torch.utils.data import Dataset, Subset
import random

class CIFAR5_vs_1:
    """Task generator over CIFAR-100 alternating 5-class and 1-class tasks.

    Even ``task_id`` values produce "hard" tasks with 5 classes; odd values
    produce "easy" tasks with a single class. Classes are drawn without
    replacement across calls to :meth:`get_task_dataset` until :meth:`reset`
    is called. For every selected class a fixed number of train and test
    images is subsampled.
    """

    # Per-channel mean / std used by the 'norm' option.
    _NORM_MEAN = (0.5071, 0.4867, 0.4408)
    _NORM_STD = (0.2675, 0.2565, 0.2761)

    def __init__(self, root='./data', samples_per_class=500, test_samples_per_class=100, seed=42, tranform=None):
        """Load CIFAR-100, build the transform pipelines and index images by class.

        Args:
            root: Directory CIFAR-100 is loaded from / downloaded to.
            samples_per_class: Number of training images sampled per selected class.
            test_samples_per_class: Number of test images sampled per selected class.
            seed: Seeds the global ``torch``, ``numpy`` and ``random`` RNGs and
                the RNG used for task selection.
            tranform: Iterable of option names from {'flip', 'crop', 'norm'};
                a single name string is also accepted, and ``None`` means no
                augmentation. ('flip' and 'crop' apply to training data only;
                'norm' applies to both splits.) The parameter's spelling is kept
                for backward compatibility.

        Raises:
            ValueError: If an unsupported option name is given.
        """
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # Accept the default None and a bare option string; iterating either
        # directly would fail (None) or split the string into characters.
        if tranform is None:
            tranform = []
        elif isinstance(tranform, str):
            tranform = [tranform]

        train_transforms = [transforms.ToTensor()]
        test_transforms = [transforms.ToTensor()]
        use_norm = False
        for transform in tranform:
            if transform == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transform == 'crop':
                train_transforms.append(transforms.RandomCrop(32, padding=4))
            elif transform == 'norm':
                use_norm = True
            else:
                raise ValueError(f"Transformation type {transform} is not supported")

        if use_norm:
            # Normalize last, so spatial augmentations (e.g. crop padding) act on
            # raw [0, 1] pixels no matter where 'norm' appeared in the options.
            normalize = transforms.Normalize(self._NORM_MEAN, self._NORM_STD)
            train_transforms.append(normalize)
            test_transforms.append(normalize)

        print(train_transforms)

        train_transforms = transforms.Compose(train_transforms)
        test_transforms = transforms.Compose(test_transforms)

        self.train_dataset = datasets.CIFAR100(root=root, train=True,
                                               transform=train_transforms, download=True)
        self.test_dataset = datasets.CIFAR100(root=root, train=False,
                                              transform=test_transforms, download=True)

        self.samples_per_class = samples_per_class
        self.test_samples_per_class = test_samples_per_class
        self.train_indices_by_class = self._organize_by_class(self.train_dataset)
        self.test_indices_by_class = self._organize_by_class(self.test_dataset)

        self.all_classes = list(range(100))
        random.shuffle(self.all_classes)
        self.used_classes = set()
        self.rng = np.random.RandomState(seed)

    class TaskSubset(Dataset):
        """View over selected indices of a dataset, optionally remapping labels.

        With ``preserve_labels=False`` each original label is translated through
        ``class_mapping`` (original class id -> task-local id). Otherwise the
        original label is returned unchanged. ``is_hard`` is stored for callers
        but not used by this class.
        """

        def __init__(self, dataset, indices, class_mapping, is_hard=True, preserve_labels=False):
            self.dataset = dataset
            self.indices = indices
            self.class_mapping = class_mapping
            self.is_hard = is_hard
            self.preserve_labels = preserve_labels

        def __getitem__(self, idx):
            image, label = self.dataset[self.indices[idx]]
            if self.preserve_labels:
                return image, label  # Keep original CIFAR-100 label
            # Remap to 0..4 for hard tasks or 0 for easy tasks.
            return image, self.class_mapping[label]

        def __len__(self):
            return len(self.indices)

    def _organize_by_class(self, dataset):
        """Return a dict mapping each class id in 0..99 to its ascending sample indices.

        Uses ``dataset.targets`` when available. This avoids loading and
        transforming every image just to read its label. Datasets without a
        ``targets`` attribute fall back to iterating ``(sample, label)`` pairs.
        """
        indices_by_class = {i: [] for i in range(100)}
        targets = getattr(dataset, 'targets', None)
        if targets is not None:
            for idx, label in enumerate(targets):
                indices_by_class[int(label)].append(idx)
        else:
            for idx, (_, label) in enumerate(dataset):
                indices_by_class[label].append(idx)
        return indices_by_class

    def get_task_dataset(self, task_id, preserve_labels=False):
        """
        Build train/test subsets for one task.

        Args:
            task_id: Integer identifying the task; even ids select 5 classes
                ("hard"), odd ids select 1 class ("easy").
            preserve_labels: If True, keep original CIFAR-100 labels instead of
                remapping them to task-local ids.

        Returns:
            ``(train_subset, test_subset, task_classes)`` where ``task_classes``
            is the numpy array of selected original class ids, in label order.

        Raises:
            ValueError: If too few unused classes remain, or if a class has
                fewer images than the requested per-class sample size. In
                either case ``used_classes`` is left unchanged.
        """
        is_hard = task_id % 2 == 0
        num_classes = 5 if is_hard else 1

        # Per-task RNG derived from the instance RNG: tasks are reproducible
        # for a given constructor seed and call order.
        task_rng = np.random.RandomState(self.rng.randint(10000))

        # NOTE(review): candidates come from a set difference, so the shuffled
        # order of all_classes likely has no effect on selection; randomness
        # comes from task_rng.
        available_classes = list(set(self.all_classes) - self.used_classes)
        if len(available_classes) < num_classes:
            raise ValueError(f"Not enough classes remaining for task {task_id}")

        task_classes = task_rng.choice(available_classes, size=num_classes, replace=False)
        task_rng.shuffle(task_classes)

        # Original class id -> task-local label (position in task_classes).
        class_mapping = {int(original): idx for idx, original in enumerate(task_classes)}

        train_indices = []
        test_indices = []
        for class_idx in task_classes:
            train_indices.extend(task_rng.choice(
                self.train_indices_by_class[class_idx],
                size=self.samples_per_class,
                replace=False,
            ))
            test_indices.extend(task_rng.choice(
                self.test_indices_by_class[class_idx],
                size=self.test_samples_per_class,
                replace=False,
            ))

        # Consume the classes only once sampling has succeeded, so a failed
        # call does not permanently remove them from the pool.
        self.used_classes.update(int(c) for c in task_classes)

        train_subset = self.TaskSubset(
            self.train_dataset, train_indices, class_mapping,
            is_hard=is_hard, preserve_labels=preserve_labels
        )
        test_subset = self.TaskSubset(
            self.test_dataset, test_indices, class_mapping,
            is_hard=is_hard, preserve_labels=preserve_labels
        )

        task_type = "hard" if is_hard else "easy"
        class_names = self.get_class_names(task_classes)
        print(f"\nTask {task_id} ({task_type}):")
        print("Selected classes:", class_names)
        if preserve_labels:
            print("Using original CIFAR-100 labels:", task_classes)
        else:
            print("Label mapping:", {name: label for label, name in enumerate(class_names)})

        return train_subset, test_subset, task_classes

    def get_class_names(self, class_indices):
        """Return the training dataset's class names for the given class ids, in order."""
        return [self.train_dataset.classes[idx] for idx in class_indices]

    def reset(self):
        """Clear used classes, reshuffle ``all_classes`` and reseed the task RNG.

        The new task RNG is seeded from the global numpy RNG.
        """
        self.used_classes = set()
        random.shuffle(self.all_classes)
        self.rng = np.random.RandomState(np.random.randint(10000))