import torch
import numpy as np
from torchvision.datasets import CIFAR100
from torch.utils.data import Subset, DataLoader, Dataset
import torchvision.transforms as transforms
from typing import List, Optional, Tuple, Union, Dict
import random

class ClassIncrementalCIFAR100:
    """Class-incremental view of CIFAR-100 for continual-learning experiments.

    The 100 class labels are shuffled with ``seed`` and dealt out to
    consecutive tasks, ``classes_per_task`` at a time. Per-class sample
    indices are computed once at construction, so the per-task ``Subset``
    views handed out afterwards are cheap to build.
    """

    # Per-channel statistics applied by the 'norm' transform.
    _NORM_MEAN = (0.5071, 0.4867, 0.4408)
    _NORM_STD = (0.2675, 0.2565, 0.2761)

    def __init__(
        self,
        root: str = './data',
        classes_per_task: int = 5,
        num_tasks: int = 20,
        samples_per_class: Optional[int] = None,
        seed: int = 42,
        transform: Optional[Union[str, List[str]]] = None,
        include_val_set: bool = False,
        val_ratio: float = 0.1,
    ):
        """
        Class incremental CIFAR100 dataset for continual learning.

        Args:
            root: Directory the CIFAR-100 files are read from / downloaded to.
            classes_per_task: Number of classes assigned to each task.
            num_tasks: Number of tasks the class order is split into.
            samples_per_class: Optional cap on the number of training
                samples kept per class (before the validation split).
            seed: Seed for the class shuffle and the per-class sampling.
            transform: Names of optional preprocessing steps, applied after
                ``ToTensor`` in the order given: 'norm', 'flip', 'crop',
                'rotate'. ``None`` means ``ToTensor`` only; a single string
                is treated as a one-element list. Only 'norm' is applied to
                validation/test data. Unrecognised names are ignored.
            include_val_set: If True, hold out part of each class's training
                samples as a validation set that uses the test transforms.
            val_ratio: Fraction of each class's (possibly capped) training
                samples moved to the validation set.

        Raises:
            ValueError: If ``include_val_set`` is True and ``val_ratio`` is
                outside ``[0, 1]``.
        """
        if include_val_set and not 0.0 <= val_ratio <= 1.0:
            raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

        self.root = root
        self.classes_per_task = classes_per_task
        self.num_tasks = num_tasks
        self.samples_per_class = samples_per_class
        self.seed = seed
        self.include_val_set = include_val_set
        self.val_ratio = val_ratio

        self.transform_train, self.transform_test = self._build_transforms(transform)
        self.target_transform = None

        self.train_dataset = CIFAR100(
            root=root, train=True, download=True,
            transform=self.transform_train,
            target_transform=self.target_transform
        )
        # The validation view reads the same training images but through the
        # deterministic test-time transforms; None when validation is off.
        self.val_dataset = None
        if self.include_val_set:
            self.val_dataset = CIFAR100(
                root=root, train=True, download=True,
                transform=self.transform_test,
                target_transform=self.target_transform
            )
        self.test_dataset = CIFAR100(
            root=root, train=False, download=True,
            transform=self.transform_test,
            target_transform=self.target_transform
        )

        # Shuffle classes.
        # NOTE: this reseeds the process-wide `random` generator as well; that
        # side effect is kept so existing experiments stay reproducible.
        self.class_order = list(range(100))
        random.seed(seed)
        random.shuffle(self.class_order)

        # Create class to task mapping; classes beyond the configured tasks
        # (or tasks beyond the 100 classes) are simply left unmapped.
        self.class_to_task = {}
        for task_id in range(num_tasks):
            for class_label in self.get_task_classes(task_id):
                self.class_to_task[class_label] = task_id

        # Per-class sample indices (validation indices are carved out of the
        # training indices below).
        train_class_indices_full = self._get_class_indices(self.train_dataset)
        self.test_class_indices = self._get_class_indices(self.test_dataset)

        if self.include_val_set:
            self.train_class_indices, self.val_class_indices = (
                self._split_train_val(train_class_indices_full)
            )
        else:
            self.train_class_indices = train_class_indices_full
            self.val_class_indices = {}
            if self.samples_per_class is not None:
                self._limit_samples_per_class()

    def _build_transforms(self, transform):
        """Build the (train, test) transform pipelines from transform names."""
        if transform is None:
            names = []
        elif isinstance(transform, str):
            # Iterating a bare string would yield single characters.
            names = [transform]
        else:
            names = list(transform)

        train_transforms = [transforms.ToTensor()]
        test_transforms = [transforms.ToTensor()]
        for name in names:
            if name == 'norm':
                # Normalisation is deterministic, so evaluation data gets it too.
                train_transforms.append(
                    transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)
                )
                test_transforms.append(
                    transforms.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD)
                )
            elif name == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif name == 'crop':
                train_transforms.append(
                    transforms.RandomCrop(32, padding=4, padding_mode='reflect')
                )
            elif name == 'rotate':
                train_transforms.append(transforms.RandomRotation(degrees=(0, 15)))
        return transforms.Compose(train_transforms), transforms.Compose(test_transforms)

    def _split_train_val(self, class_indices):
        """Split per-class indices into (train, val) dicts of index lists."""
        train_split = {}
        val_split = {}
        for class_label, indices in class_indices.items():
            # Use a per-class seeded random for reproducibility
            rng = random.Random(self.seed + class_label)
            shuffled_indices = list(indices)
            rng.shuffle(shuffled_indices)

            # Apply sample limit if specified
            if self.samples_per_class is not None:
                shuffled_indices = shuffled_indices[:self.samples_per_class]

            # The small epsilon absorbs float error in the product, so that
            # e.g. 500 * 0.1 -> 50 and 100 * 0.29 -> 29 rather than 28.
            val_size = int(len(shuffled_indices) * self.val_ratio + 1e-9)

            val_split[class_label] = shuffled_indices[:val_size]
            train_split[class_label] = shuffled_indices[val_size:]
        return train_split, val_split

    def _get_class_indices(self, dataset):
        """Get indices of samples for each class.

        When the dataset exposes a ``targets`` sequence and has no
        ``target_transform``, labels are read from it directly, which avoids
        loading and transforming every image. Otherwise the dataset is
        iterated and the label of each ``(sample, target)`` item is used.
        """
        targets = getattr(dataset, 'targets', None)
        if targets is None or getattr(dataset, 'target_transform', None) is not None:
            targets = (target for _, target in dataset)

        class_indices = {}
        for idx, target in enumerate(targets):
            class_indices.setdefault(target, []).append(idx)
        return class_indices

    def _limit_samples_per_class(self):
        """Limit the number of samples per class."""
        for class_label, indices in self.train_class_indices.items():
            if len(indices) > self.samples_per_class:
                # Per-class seeded generator keeps the selection reproducible.
                rng = random.Random(self.seed + class_label)
                self.train_class_indices[class_label] = rng.sample(
                    indices, self.samples_per_class
                )

    def get_task_classes(self, task_id: int) -> List[int]:
        """Get the classes for a specific task."""
        start_class = task_id * self.classes_per_task
        end_class = min((task_id + 1) * self.classes_per_task, len(self.class_order))
        return [self.class_order[i] for i in range(start_class, end_class)]

    def get_task_datasets(self, task_id: int, include_previous_data: bool = False, dataset_type: str = 'train') -> Dataset:
        """
        Get dataset for a specific task and dataset type.
        Args:
            task_id: The task ID (0-indexed)
            include_previous_data: If True, include data from previous tasks
            dataset_type: 'train', 'val', or 'test'
        Returns:
            Dataset (a ``Subset``) for the given task and type.
        Raises:
            ValueError: If task_id is out of range, dataset_type is unknown,
                or 'val' is requested while the validation set is disabled.
        """
        # Negative ids are rejected too: they would otherwise silently index
        # from the end of the class order.
        if not 0 <= task_id < self.num_tasks:
            raise ValueError(f"Task ID {task_id} is out of range (max: {self.num_tasks-1})")

        # Determine which classes belong to this task (and possibly previous tasks)
        if include_previous_data:
            task_classes = []
            for tid in range(task_id + 1):
                task_classes.extend(self.get_task_classes(tid))
        else:
            task_classes = self.get_task_classes(task_id)

        # Select the appropriate indices and base dataset
        if dataset_type == 'train':
            class_indices_dict = self.train_class_indices
            base_dataset = self.train_dataset
        elif dataset_type == 'val' and self.include_val_set:
            class_indices_dict = self.val_class_indices
            # Use the dedicated validation dataset with test transforms
            base_dataset = self.val_dataset
        elif dataset_type == 'test':
            class_indices_dict = self.test_class_indices
            base_dataset = self.test_dataset
        else:
            raise ValueError(f"Invalid dataset_type '{dataset_type}' or validation set not enabled")

        indices = []
        for class_id in task_classes:
            indices.extend(class_indices_dict.get(class_id, []))

        # Return the Subset itself (not wrapped in a tuple) so len() and
        # indexing behave like a dataset.
        return Subset(base_dataset, indices)

    def get_data_loaders(
        self,
        task_id: int,
        include_previous_data: bool = False,
        batch_size: int = 128,
        shuffle: bool = True,
        num_workers: int = 4
    ) -> Tuple[DataLoader, DataLoader]:
        """
        Get data loaders for a specific task.

        Args:
            task_id: The task ID (0-indexed)
            include_previous_data: If True, include data from previous tasks
            batch_size: Batch size used by both loaders
            shuffle: Whether the training loader shuffles; the test loader
                never does
            num_workers: Worker processes used by each loader
        Returns:
            A ``(train_loader, test_loader)`` pair. Validation data, when
            enabled, is available through ``get_task_datasets(..., 'val')``.
        """
        train_dataset = self.get_task_datasets(task_id, include_previous_data, 'train')
        test_dataset = self.get_task_datasets(task_id, include_previous_data, 'test')

        train_loader = DataLoader(
            train_dataset, batch_size=batch_size,
            shuffle=shuffle, num_workers=num_workers
        )
        test_loader = DataLoader(
            test_dataset, batch_size=batch_size,
            shuffle=False, num_workers=num_workers
        )
        return train_loader, test_loader

    def get_full_test_loader(self, batch_size: int = 128, num_workers: int = 4) -> DataLoader:
        """Get a test loader that contains all classes."""
        return DataLoader(
            self.test_dataset, batch_size=batch_size,
            shuffle=False, num_workers=num_workers
        )

    def get_task_info(self) -> Dict:
        """Get information about the tasks and class distribution.

        Returns:
            A dict keyed by ``"Task <id>"`` whose values hold the task's
            "classes", "train_samples", "test_samples" and, when the
            validation set is enabled, "val_samples".
        """
        task_info = {}
        for task_id in range(self.num_tasks):
            task_info_dict = {
                "classes": self.get_task_classes(task_id),
                "train_samples": len(self.get_task_datasets(task_id, dataset_type='train')),
                "test_samples": len(self.get_task_datasets(task_id, dataset_type='test')),
            }
            if self.include_val_set:
                task_info_dict["val_samples"] = len(
                    self.get_task_datasets(task_id, dataset_type='val')
                )
            task_info[f"Task {task_id}"] = task_info_dict
        return task_info


# Example usage
if __name__ == "__main__":
    # Initialize the dataset. An empty `transform` list selects no optional
    # preprocessing (ToTensor only); add 'norm', 'flip', 'crop' or 'rotate'
    # as needed.
    cifar_cl = ClassIncrementalCIFAR100(
        root='./data',
        classes_per_task=5,
        num_tasks=20,
        samples_per_class=500,  # Limit to 500 samples per class (or None for all)
        seed=42,
        transform=[],
    )

    # Print task info
    task_info = cifar_cl.get_task_info()
    for task, info in task_info.items():
        print(f"{task}: Classes {info['classes']}, Train samples: {info['train_samples']}, Test samples: {info['test_samples']}")

    # (task_id, include_previous_data) combinations to demonstrate. For task 0
    # both settings cover the same classes; task 3 is only shown when it exists.
    scenarios = [(0, False), (0, True)]
    if cifar_cl.num_tasks > 3:
        scenarios.append((3, True))

    for task_id, include_previous_data in scenarios:
        # get_data_loaders hands back a (train, test) pair.
        train_loader, test_loader = cifar_cl.get_data_loaders(
            task_id=task_id,
            include_previous_data=include_previous_data,
            batch_size=64
        )
        mode = "with previous data" if include_previous_data else "without previous data"
        print(f"\nTask {task_id} ({mode}):")
        print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")