import torchvision.transforms as transforms
import torchvision.datasets as datasets
from typing import Callable, Tuple, Any
from torch.utils.data import Dataset

#from .continual_imagenet import *
from .five_one_cifar100 import *
from .collas_utils import *
from .cifar100_continual import *
from .continual_tiny_imagenet import *
from .cifar100_class_incremental import *
from .cifar10_continual import * 
import math
from torchvision.datasets.vision import VisionDataset

import os
import pickle
import numpy as np
from torchvision.datasets import VisionDataset
from PIL import Image

class MiniImageNet(VisionDataset):
    """Mini-ImageNet split loaded from a pickled cache file.

    The cache is read from
    ``<root>/mini_imagenet/1/mini-imagenet-cache-<split>.pkl`` and must be a
    mapping with the keys ``'image_data'`` and ``'class_dict'``. Integer
    labels are assigned by enumerating ``class_dict`` in its iteration order,
    so label ``k`` refers to the k-th class of the cache file for *this* split.

    Attributes set by ``__init__``:
        targets: ``np.int64`` array holding one label per image.
        data: the cached image array, or ``None`` when ``only_targets=True``.
    """

    def __init__(self, root='/localhome/srr8/project/iclr_rebuttal/17/tasks/data', split='train',
                 only_targets=False,
                 transform=None,
                 target_transform=None):
        """
        Args:
            root: root folder that contains `mini_imagenet/1/`
            split: 'train' | 'val' | 'test'
            only_targets: when True, keep only the labels and drop the image
                array after loading; ``__getitem__`` then raises.
            transform: optional callable applied to the PIL image.
            target_transform: optional callable applied to the integer label.

        Raises:
            ValueError: if ``split`` is not one of the supported names, or if
                ``class_dict`` does not assign a label to every image.
        """
        # Explicit check instead of `assert`: asserts are removed under `-O`,
        # which would let a bad split name fall through to a confusing
        # file-not-found error.
        if split not in {'train', 'val', 'test'}:
            raise ValueError(
                f"split must be one of 'train', 'val', 'test'; got {split!r}"
            )
        super().__init__(root, transform=transform, target_transform=target_transform)

        cache_path = os.path.join(
            root, 'mini_imagenet', '1',
            f'mini-imagenet-cache-{split}.pkl'
        )

        # NOTE(security): unpickling can execute arbitrary code embedded in the
        # file. Only point `root` at cache files from a trusted source.
        with open(cache_path, 'rb') as f:
            cache = pickle.load(f)     # usually {'image_data': ..., 'class_dict': ...}

        image_data = cache['image_data']          # expected: (N, H, W, C) numpy array
        class_dict = cache['class_dict']          # expected: {class_name: [idx, idx, ...]}

        # Build the targets array: one label per image. Start from a -1
        # sentinel (rather than uninitialised memory) so that an image missing
        # from every class list is detected instead of receiving a garbage label.
        targets = np.full(len(image_data), -1, dtype=np.int64)
        for label, idxs in enumerate(class_dict.values()):
            targets[idxs] = label

        unlabeled = int((targets < 0).sum())
        if unlabeled:
            raise ValueError(
                f"{unlabeled} image(s) in {cache_path} are not listed in 'class_dict'"
            )
        self.targets = targets

        # store data if requested
        self.data = None if only_targets else image_data

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index`` after applying the transforms.

        Raises:
            RuntimeError: if the dataset was built with ``only_targets=True``.
        """
        if self.data is None:
            raise RuntimeError("Dataset was created with only_targets=True; no image data loaded.")

        img = self.data[index]
        target = int(self.targets[index])

        # cache files usually store images as uint8 HWC numpy arrays;
        # convert to PIL so the transform pipeline receives an image object
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples (one target per image)."""
        return len(self.targets)

def manage_task_id(task_id, reoccurring_type, reoccurring_freq):
    """Translate a raw task index into the id to load plus a recurrence flag.

    Behaviour per ``reoccurring_type``:

    * ``'none'``: ``task_id`` is returned unchanged with ``False``;
      ``reoccurring_freq`` is not consulted.
    * ``'fixed_seq'``: the id becomes ``task_id % reoccurring_freq``; the flag
      is ``True`` exactly when that remainder is zero.
    * ``'random_seq'``: when ``task_id % reoccurring_freq`` is zero, the
      remainder is returned with ``True``; otherwise ``task_id`` is returned
      unchanged with ``False``.

    Returns:
        A ``(task_id, reoccured)`` tuple.

    Raises:
        ValueError: if ``reoccurring_type`` is none of the values above.
    """
    if reoccurring_type == 'none':
        return task_id, False

    # Reject unknown modes before touching `reoccurring_freq`.
    if reoccurring_type not in ("fixed_seq", "random_seq"):
        raise ValueError(f"Reoccurring type {reoccurring_type} is not supported")

    remainder = task_id % reoccurring_freq

    # Both modes behave the same on a multiple of the frequency.
    if remainder == 0:
        return remainder, True

    # Off-cycle: fixed sequences wrap around, random sequences keep the raw id.
    if reoccurring_type == "fixed_seq":
        return remainder, False
    return task_id, False

def configure_dataset(task_config, arch_config, use_hat, args) -> Tuple[Callable, int, int]:
    """
    Configure dataset based on benchmark type and return necessary components.
    
    Args:
        config: Configuration object containing benchmark and training settings
        args: Arguments object containing seed and other runtime settings
    
    Returns:
        Tuple containing:
        - get_task_dataset function
        - input dimension
        - number of classes
    """
    
    if task_config.benchmark == "permuted_MNIST":
        train_transforms = [transforms.ToTensor()]
        test_transforms = [transforms.ToTensor()]
        for transform in task_config.tranform:
            if transform == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transform == 'crop':
                train_transforms.append(transforms.RandomCrop(28, padding=4))
            elif transform == 'norm':
                train_transforms.append(transforms.Normalize((0.1307,), (0.3081,)))
                test_transforms.append(transforms.Normalize((0.1307,), (0.3081,)))
            else:
                raise ValueError(f"Transformation type {transform} is not supported")

        not_aug_train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        print(train_transforms)

        train_transforms = transforms.Compose(train_transforms)
        test_transforms = transforms.Compose(test_transforms)
            
        train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=train_transforms)
        not_aug_train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=not_aug_train_transforms)
        if task_config.limit is not None:
            train_dataset = get_limited_dataset(
                    train_dataset, 
                    task_config.limit, 
                    seed= args.seed
                )
            not_aug_train_dataset = get_limited_dataset(
                    not_aug_train_dataset, 
                    task_config.limit, 
                    seed= args.seed
                )
            
        test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=test_transforms)
        input_dim = task_config.input_shape
        num_classes = 10
        
        def get_task_dataset(task_id):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            permutation = get_permutation(seed=task_id, size=math.prod(input_dim))
            perm_train_dataset = permute_mnist_dataset(train_dataset, permutation)
            perm_test_dataset = permute_mnist_dataset(test_dataset, permutation)
            perm_not_aug_train_dataset = permute_mnist_dataset(not_aug_train_dataset, permutation)
            return perm_train_dataset, perm_test_dataset, perm_not_aug_train_dataset, reoccured

    elif task_config.benchmark == "random_MNIST":
        train_transforms = [transforms.ToTensor()]
        test_transforms = [transforms.ToTensor()]
        for transform in task_config.tranform:
            if transform == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transform == 'crop':
                train_transforms.append(transforms.RandomCrop(28, padding=4))
            elif transform == 'norm':
                train_transforms.append(transforms.Normalize((0.1307,), (0.3081,)))
                test_transforms.append(transforms.Normalize((0.1307,), (0.3081,)))
            else:
                raise ValueError(f"Transformation type {transform} is not supported")
        
        print(train_transforms)
        not_aug_train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        train_transforms = transforms.Compose(train_transforms)
        test_transforms = transforms.Compose(test_transforms)

        train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=train_transforms)
        not_aug_train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=not_aug_train_transforms)
        test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=test_transforms)
        input_dim = task_config.input_shape
        num_classes = 10
        
        def get_task_dataset(task_id):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            train_task_dataset = RandomLabelDataset(
                train_dataset,
                seed=args.seed + task_id,
                num_samples=task_config.limit if task_config.limit else None
            )
            not_aug_train_task_dataset = RandomLabelDataset(
                not_aug_train_dataset,
                seed=args.seed + task_id,
                num_samples=task_config.limit if task_config.limit else None
            )
            test_task_dataset = RandomLabelDataset(test_dataset, seed= args.seed + task_id)
            return train_task_dataset, test_task_dataset, not_aug_train_task_dataset, reoccured

    elif task_config.benchmark == "shuffle_cifar10":
        # We'll use CIFAR10 with randomly permuted labels for each task
        train_transforms = [transforms.ToTensor()]
        test_transforms = [transforms.ToTensor()]
        for transform in task_config.tranform:
            if transform == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transform == 'crop':
                train_transforms.append(transforms.RandomCrop(32, padding=4))
            elif transform == 'norm':
                train_transforms.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)))
                test_transforms.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)))
            else:
                raise ValueError(f"Transformation type {transform} is not supported")

        not_aug_train_transforms = transforms.Compose([transforms.ToTensor(), 
                                                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
        print(train_transforms)
        train_transforms = transforms.Compose(train_transforms)
        test_transforms = transforms.Compose(test_transforms)

        train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transforms)
        not_aug_train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=not_aug_train_transforms)
        test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=test_transforms)
        input_dim = task_config.input_shape
        num_classes = 10
        
        # If config.limit is 5000, this will create a dataset with 500 samples per class
        if task_config.limit is not None:
            samples_per_class = task_config.limit // 10 
        else:
            samples_per_class = None
        train_dataset = create_balanced_dataset(train_dataset, samples_per_class, seed=args.seed)
        not_aug_train_dataset = create_balanced_dataset(not_aug_train_dataset, samples_per_class, seed=args.seed)
        
        def get_task_dataset(task_id):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            # create a label permutation
            label_perm = get_label_permutation(num_classes=10, seed=task_id)
            perm_train_dataset = apply_label_permutation(train_dataset, label_perm, flatten= False)
            perm_not_aug_train_dataset = apply_label_permutation(not_aug_train_dataset, label_perm, flatten= False)
            perm_test_dataset = apply_label_permutation(test_dataset, label_perm, flatten= False)
            return perm_train_dataset, perm_test_dataset, perm_not_aug_train_dataset, reoccured
    
    elif task_config.benchmark == "shuffle_mini_imagenet":
        # We'll use Mini-ImageNet with randomly permuted labels for each task
        train_transforms = [transforms.ToTensor()]
        test_transforms = [transforms.ToTensor()]

        # Example: use config the same way as CIFAR10
        for transform in task_config.tranform:
            if transform == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transform == 'crop':
                # mini-ImageNet is 84x84; you can tune this if you want
                train_transforms.insert(0, transforms.RandomResizedCrop(84, scale=(0.8, 1.0)))
            elif transform == 'norm':
                # Use ImageNet statistics (commonly also used for mini-ImageNet)
                imagenet_mean = (0.485, 0.456, 0.406)
                imagenet_std  = (0.229, 0.224, 0.225)
                train_transforms.append(transforms.Normalize(imagenet_mean, imagenet_std))
                test_transforms.insert(0, transforms.Resize(92))
                test_transforms.append(transforms.CenterCrop(84))
                test_transforms.append(transforms.Normalize(imagenet_mean, imagenet_std))
            else:
                raise ValueError(f"Transformation type {transform} is not supported")

        not_aug_train_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                (0.229, 0.224, 0.225))
        ])

        train_transforms = transforms.Compose(train_transforms)
        test_transforms = transforms.Compose(test_transforms)

        # Replace CIFAR10 with your Mini-ImageNet dataset class
        # Adjust import/arguments to match your implementation
        # Example if you have a MiniImageNet(root, split, transform, download):
        train_dataset = MiniImageNet(
                                    split="train",
                                    #download=True,
                                    transform=train_transforms)
        not_aug_train_dataset = MiniImageNet(
                                            split="train",
                                            #download=True,
                                            transform=not_aug_train_transforms)
        test_dataset = MiniImageNet(
                                    split="test",
                                    #download=True,
                                    transform=test_transforms)

        input_dim = task_config.input_shape

        # Infer num_classes from the dataset if possible
        num_classes = task_config.num_classes

        # If config.limit is, say, 5000, this will create a dataset with 5000/num_classes samples per class
        if task_config.limit is not None:
            samples_per_class = task_config.limit // num_classes
        else:
            samples_per_class = None

        train_dataset = create_balanced_dataset(train_dataset, samples_per_class, seed=args.seed)
        not_aug_train_dataset = create_balanced_dataset(not_aug_train_dataset, samples_per_class, seed=args.seed)

        def get_task_dataset(task_id):
            task_id, reoccured = manage_task_id(task_id,
                                                task_config.reoccurring_type,
                                                task_config.reoccurring_freq)

            # create a label permutation for THIS task
            label_perm = get_label_permutation(num_classes=num_classes, seed=task_id)

            perm_train_dataset = apply_label_permutation(train_dataset, label_perm, flatten=False)
            perm_not_aug_train_dataset = apply_label_permutation(not_aug_train_dataset, label_perm, flatten=False)
            perm_test_dataset = apply_label_permutation(test_dataset, label_perm, flatten=False)

            return perm_train_dataset, perm_test_dataset, perm_not_aug_train_dataset, reoccured


    elif task_config.benchmark == "random_label_cifar10":
        # Build a list of transforms for training
        train_transforms = [transforms.ToTensor()]
        test_transforms = [transforms.ToTensor()]
        for transform in task_config.tranform:
            if transform == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transform == 'crop':
                train_transforms.append(transforms.RandomCrop(32, padding=4))
            elif transform == 'norm':
                train_transforms.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)))
                test_transforms.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)))
            else:
                raise ValueError(f"Transformation type {transform} is not supported")

        print(train_transforms)
        not_aug_cifar_train_transform = transforms.Compose([transforms.ToTensor(), 
                                                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
        # Compose the transforms
        cifar_train_transform = transforms.Compose(train_transforms)
        cifar_test_transform = transforms.Compose(test_transforms)
        
        train_dataset = datasets.CIFAR10(
            root='./data', train=True, transform=cifar_train_transform, download=True
        )
        not_aug_train_dataset = datasets.CIFAR10(
            root='./data', train=True, transform=not_aug_cifar_train_transform, download=True
        )
        test_dataset = datasets.CIFAR10(
            root='./data', train=False, transform=cifar_test_transform, download=True
        )

        input_dim = task_config.input_shape
        num_classes = 10

        def get_task_dataset(task_id):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            random_train_dataset = RandomLabelDataset(
                train_dataset,
                seed=args.seed + task_id,
                num_samples=task_config.limit
            )
            random_not_aug_train_dataset = RandomLabelDataset(
                not_aug_train_dataset,
                seed=args.seed + task_id,
                num_samples=task_config.limit
            )
            random_test_dataset = RandomLabelDataset(test_dataset, seed=args.seed + task_id)
            return random_train_dataset, random_test_dataset, random_not_aug_train_dataset, reoccured
        
    elif task_config.benchmark == "random_label_cifar100":
        # Build a list of transforms for training
        train_transforms = [transforms.ToTensor()]
        test_transforms = [transforms.ToTensor()]
        for transform in task_config.tranform:
            if transform == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transform == 'crop':
                train_transforms.append(transforms.RandomCrop(32, padding=4))
            elif transform == 'norm':
                train_transforms.append(transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)))
                test_transforms.append(transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)))
            else:
                raise ValueError(f"Transformation type {transform} is not supported")

        print(train_transforms)
        not_aug_cifar_train_transform = transforms.Compose([transforms.ToTensor(), 
                                                       transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))])
        # Compose the transforms
        cifar_train_transform = transforms.Compose(train_transforms)
        cifar_test_transform = transforms.Compose(test_transforms)
        
        train_dataset = datasets.CIFAR100(
            root='./data', train=True, transform=cifar_train_transform, download=True
        )
        not_aug_train_dataset = datasets.CIFAR100(
            root='./data', train=True, transform=not_aug_cifar_train_transform, download=True
        )
        test_dataset = datasets.CIFAR100(
            root='./data', train=False, transform=cifar_test_transform, download=True
        )

        input_dim = task_config.input_shape
        num_classes = 100

        def get_task_dataset(task_id):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            random_train_dataset = RandomLabelDataset(
                train_dataset,
                seed=args.seed + task_id,
                num_samples=task_config.limit
            )
            random_not_aug_train_dataset = RandomLabelDataset(
                not_aug_train_dataset,
                seed=args.seed + task_id,
                num_samples=task_config.limit
            )
            random_test_dataset = RandomLabelDataset(test_dataset, seed=args.seed + task_id)
            return random_train_dataset, random_test_dataset, random_not_aug_train_dataset, reoccured
    
    elif task_config.benchmark == "continual_cifar100":
        cifar_continual = CIFAR100Continual(
            root='./data',
            samples_per_class=task_config.samples_per_class,
            test_samples_per_class=task_config.test_samples_per_class,
            tranform=task_config.tranform,
            seed=args.seed
        )
        input_dim = task_config.input_shape
        num_classes = 100
        def get_task_dataset(task_id):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            train_dataset_cifar100, test_dataset_cifar100, not_aug_train_dataset_cifar100, task_classes = cifar_continual.get_task_dataset(task_id = task_id, preserve_labels=not use_hat)
            class_names = cifar_continual.get_class_names(task_classes)
            #print(f"Classes in this task: {class_names}")
            return train_dataset_cifar100, test_dataset_cifar100, not_aug_train_dataset_cifar100, reoccured
    
    elif task_config.benchmark == "continual_cifar10":
        cifar_continual = CIFAR10Continual(
            root='./data',
            samples_per_class=task_config.samples_per_class,
            test_samples_per_class=task_config.test_samples_per_class,
            tranform=task_config.tranform,
            seed=args.seed
        )
        input_dim = task_config.input_shape
        num_classes = 10
        def get_task_dataset(task_id):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            train_dataset_cifar10, test_dataset_cifar10, not_aug_train_dataset_cifar10, task_classes = cifar_continual.get_task_dataset(task_id = task_id, preserve_labels=not use_hat)
            class_names = cifar_continual.get_class_names(task_classes)
            #print(f"Classes in this task: {class_names}")
            return train_dataset_cifar10, test_dataset_cifar10, not_aug_train_dataset_cifar10, reoccured

    elif task_config.benchmark == "continual_imagenet":
        input_dim = task_config.input_shape # RGB images with 32x32 resolution
        num_classes = task_config.num_classes  # Each task distinguishes between two classes
        
        # Load Tiny ImageNet dataset with automatic download if needed
        tiny_imagenet_train, tiny_imagenet_val, class_to_idx = load_tiny_imagenet('./data/tiny-imagenet')
        
        # Total number of classes in Tiny ImageNet
        total_classes = len(class_to_idx)
        print(f"Total available classes: {total_classes}")
        train_transforms = [transforms.Resize((32, 32))]
        test_transforms = [transforms.Resize((32, 32))]
        if isinstance(tiny_imagenet_train[0][0], torch.Tensor):
            train_transforms.append(transforms.ToTensor())
            test_transforms.append(transforms.ToTensor())
            
        for transf in task_config.tranform:
            if transf == 'norm':
                train_transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
                test_transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
            elif transf == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transf == 'crop':
                train_transforms.append(transforms.RandomCrop(32, padding=4))
        
        print(train_transforms)

        not_aug_train_transforms = transforms.Compose([transforms.ToTensor(), 
                                                       transforms.Resize((32, 32)),
                                                       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        train_transforms = transforms.Compose(train_transforms)
        test_transforms = transforms.Compose(test_transforms)
        
        # Store which classes have been used in tasks
        used_classes = set()
        select_tasks = []

        
        def get_task_dataset(task_id, bwt=False):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            nonlocal used_classes  # Access the outer scope variable
            nonlocal select_tasks
            
            # Deterministically set the random seed based on task_id for reproducibility
            random.seed(args.seed + task_id)
            np.random.seed(args.seed + task_id)
            
            # Select two random classes that haven't been used before
            if bwt:
                selected_classes = select_tasks[task_id]
            else:
                available_classes = list(set(range(total_classes)) - used_classes)
                if len(available_classes) < 2:
                    used_classes = set()
                available_classes = list(set(range(total_classes)) - used_classes)
                
                # Randomly select two classes
                selected_classes = sorted(random.sample(available_classes, 2))
                select_tasks.append(selected_classes)
                
                # Add selected classes to used_classes
                for cls in selected_classes:
                    used_classes.add(cls)
            
            print(f"Task {task_id}: Selected classes {selected_classes}")
            
            # Filter datasets to only include selected classes
            train_indices = [i for i, label in enumerate(tiny_imagenet_train.labels) 
                            if label in selected_classes]
            val_indices = [i for i, label in enumerate(tiny_imagenet_val.labels) 
                            if label in selected_classes]
            
            # Ensure we have exactly 600 samples per class for training
            train_indices_class0 = [i for i in train_indices 
                                if tiny_imagenet_train.labels[i] == selected_classes[0]][:task_config.limit]  
            train_indices_class1 = [i for i in train_indices 
                                if tiny_imagenet_train.labels[i] == selected_classes[1]][:task_config.limit]
            
            # Combine indices
            train_indices = train_indices_class0 + train_indices_class1
            
            # Create subsets
            train_dataset = Subset(tiny_imagenet_train, train_indices)
            test_dataset = Subset(tiny_imagenet_val, val_indices)
            
            # Define a custom dataset that applies transformations and remaps classes to 0 and 1
            class TaskDataset(Dataset):
                def __init__(self, subset, class_mapping, transform=None):
                    self.subset = subset
                    self.class_mapping = class_mapping
                    self.transform = transform
                    
                def __len__(self):
                    return len(self.subset)
                        
                def __getitem__(self, idx):
                    img, label = self.subset[idx]
                    # Remap label to 0 or 1
                    if use_hat:
                        new_label = self.class_mapping[label]
                    else:
                        new_label = label
                    
                    # Check if img is already a tensor, if so, don't apply transform.ToTensor() again
                    if isinstance(img, torch.Tensor):
                        # Just apply any resizing and normalization
                        if self.transform:
                            # Create a transform that doesn't include ToTensor
                            transform_without_totensor = transforms.Compose([
                                t for t in self.transform.transforms 
                                if not isinstance(t, transforms.ToTensor)
                            ])
                            img = transform_without_totensor(img)
                    else:
                        # Apply the full transform pipeline
                        if self.transform:
                            img = self.transform(img)
                            
                    return img, new_label
            
            # Create mapping from original class indices to 0 and 1
            class_mapping = {selected_classes[0]: 0, selected_classes[1]: 1}
            
            # Apply transforms and class mapping
            train_dataset_imagenet = TaskDataset(train_dataset, class_mapping, train_transforms)
            not_aug_train_dataset_imagenet = TaskDataset(train_dataset, class_mapping, not_aug_train_transforms)
            test_dataset_imagenet = TaskDataset(test_dataset, class_mapping, test_transforms)
            
            print(f"Task {task_id}: Created datasets with {len(train_dataset_imagenet)} training and {len(test_dataset_imagenet)} testing samples")
            
            return train_dataset_imagenet, test_dataset_imagenet, not_aug_train_dataset_imagenet, reoccured
        
    elif task_config.benchmark == "new_continual_imagenet":
        input_dim = task_config.input_shape # RGB images with 32x32 resolution
        num_classes = task_config.num_classes  # Each task distinguishes between two classes
        
        # Load Tiny ImageNet dataset with automatic download if needed
        tiny_imagenet_train, tiny_imagenet_val, class_to_idx = load_tiny_imagenet('./data/tiny-imagenet')
        
        # Total number of classes in Tiny ImageNet
        total_classes = len(class_to_idx)
        print(f"Total available classes: {total_classes}")
        train_transforms = [transforms.Resize((32, 32))]
        test_transforms = [transforms.Resize((32, 32))]
        if isinstance(tiny_imagenet_train[0][0], torch.Tensor):
            train_transforms.append(transforms.ToTensor())
            test_transforms.append(transforms.ToTensor())
            
        for transf in task_config.tranform:
            if transf == 'norm':
                train_transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
                test_transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
            elif transf == 'flip':
                train_transforms.append(transforms.RandomHorizontalFlip())
            elif transf == 'crop':
                train_transforms.append(transforms.RandomCrop(32, padding=4))
        
        print(train_transforms)
        not_aug_train_transforms = transforms.Compose([transforms.Resize((32, 32)),
                                                       transforms.ToTensor(),
                                                       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        train_transforms = transforms.Compose(train_transforms)
        test_transforms = transforms.Compose(test_transforms)
        
        # Store which classes have been used in tasks
        used_classes = set()
        select_tasks = []

        
        def get_task_dataset(task_id, bwt=False):
            task_id, reoccured = manage_task_id(task_id, task_config.reoccurring_type, task_config.reoccurring_freq)
            nonlocal used_classes  # Access the outer scope variable
            nonlocal select_tasks
            
            # Deterministically set the random seed based on task_id for reproducibility
            random.seed(args.seed + task_id)
            np.random.seed(args.seed + task_id)
            
            # Select two random classes that haven't been used before
            if bwt:
                selected_classes = select_tasks[task_id]
            else:
                available_classes = list(set(range(total_classes)) - used_classes)
                if len(available_classes) < 20:
                    used_classes = set()
                available_classes = list(set(range(total_classes)) - used_classes)
                
                # Randomly select two classes
                selected_classes = sorted(random.sample(available_classes, 20))
                select_tasks.append(selected_classes)
                
                # Add selected classes to used_classes
                for cls in selected_classes:
                    used_classes.add(cls)
            
            print(f"Task {task_id}: Selected classes {selected_classes}")
            
            # Filter datasets to only include selected classes
            train_indices = [i for i, label in enumerate(tiny_imagenet_train.labels) 
                            if label in selected_classes]
            val_indices = [i for i, label in enumerate(tiny_imagenet_val.labels) 
                            if label in selected_classes]
            
            # Ensure we have exactly 600 samples per class for training
            train_indices_class0 = [i for i in train_indices 
                                if tiny_imagenet_train.labels[i] == selected_classes[0]][:task_config.limit]
            train_indices_class1 = [i for i in train_indices 
                                if tiny_imagenet_train.labels[i] == selected_classes[1]][:task_config.limit]
            
            # Combine indices
            train_indices = train_indices_class0 + train_indices_class1
            
            # Create subsets
            train_dataset = Subset(tiny_imagenet_train, train_indices)
            test_dataset = Subset(tiny_imagenet_val, val_indices)
            
            # Define a custom dataset that applies transformations and remaps classes to 0 and 1
            class TaskDataset(Dataset):
                def __init__(self, subset, class_mapping, transform=None):
                    self.subset = subset
                    self.class_mapping = class_mapping
                    self.transform = transform
                    
                def __len__(self):
                    return len(self.subset)
                        
                def __getitem__(self, idx):
                    img, label = self.subset[idx]
                    # Remap label to 0 or 1
                    if use_hat:
                        new_label = self.class_mapping[label]
                    else:
                        new_label = label
                    
                    # Check if img is already a tensor, if so, don't apply transform.ToTensor() again
                    if isinstance(img, torch.Tensor):
                        # Just apply any resizing and normalization
                        if self.transform:
                            # Create a transform that doesn't include ToTensor
                            transform_without_totensor = transforms.Compose([
                                t for t in self.transform.transforms 
                                if not isinstance(t, transforms.ToTensor)
                            ])
                            img = transform_without_totensor(img)
                    else:
                        # Apply the full transform pipeline
                        if self.transform:
                            img = self.transform(img)
                            
                    return img, new_label
            
            # Create mapping from original class indices to 0 and 1
            class_mapping = {}
            for index in range(len(selected_classes)):
                class_mapping[selected_classes[index]] = index
            #class_mapping = {selected_classes[0]: 0, selected_classes[1]: 1}
            
            # Apply transforms and class mapping
            train_dataset_imagenet = TaskDataset(train_dataset, class_mapping, train_transforms)
            not_aug_train_dataset_imagenet = TaskDataset(train_dataset, class_mapping, not_aug_train_transforms)
            test_dataset_imagenet = TaskDataset(test_dataset, class_mapping, test_transforms)
            
            print(f"Task {task_id}: Created datasets with {len(train_dataset_imagenet)} training and {len(test_dataset_imagenet)} testing samples")
            
            return train_dataset_imagenet, test_dataset_imagenet, not_aug_train_dataset_imagenet, reoccured
    else:
        raise ValueError(f"Dataset {task_config.benchmark} is not supported")

    return get_task_dataset, input_dim, num_classes