import torch
from torchvision import datasets, transforms as T


def generate_task_classes(task_id, num_tasks, root='/root/cifar10'):
    """
    Build the CIFAR-10 train/test subsets for one task of a class-split sequence.

    The 10 CIFAR-10 labels are divided into ``num_tasks`` contiguous groups of
    equal size, and task ``task_id`` receives the ``task_id``-th group. For
    example, ``num_tasks=2`` yields ``[0..4]`` and ``[5..9]``;
    ``num_tasks=5`` yields ``[0, 1]``, ``[2, 3]``, ..., ``[8, 9]``.

    Args:
        task_id (int): current task id, must be in ``range(num_tasks)``
        num_tasks (int): total number of tasks; must evenly divide the 10
            classes (i.e. one of 1, 2, 5, 10)
        root (str): dataset root path (passed to ``datasets.CIFAR10`` with
            ``download=True``)

    Returns:
        train_df (torch.utils.data.Subset): training samples whose label
            belongs to the current task
        test_df (torch.utils.data.Subset): test samples whose label belongs
            to the current task

    Raises:
        ValueError: if ``num_tasks`` does not evenly divide the 10 classes, or
            if ``task_id`` is not a valid task id for ``num_tasks``.
    """
    num_classes = 10  # CIFAR-10 labels are 0..9
    # Only splits with the same number of classes per task are supported.
    valid_task_counts = [n for n in range(1, num_classes + 1) if num_classes % n == 0]
    if num_tasks not in valid_task_counts:
        raise ValueError("Number of tasks needs to be checked.")
    num_tasks = int(num_tasks)
    per_task = num_classes // num_tasks
    task_classes = {
        t: list(range(t * per_task, (t + 1) * per_task)) for t in range(num_tasks)
    }
    # Validate before touching the dataset so a bad id fails fast and clearly.
    if task_id not in task_classes:
        raise ValueError(
            f"task_id must be in range({num_tasks}), got {task_id!r}."
        )
    classes = task_classes[task_id]

    # ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 maps them to [-1, 1].
    transform = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    full_train = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    full_test = datasets.CIFAR10(root=root, train=False, download=True, transform=transform)

    def filter_by_class(dataset, class_list):
        # Read labels from the raw ``targets`` list rather than iterating the
        # dataset: ``dataset[i]`` would decode and transform every image just
        # to expose its label. Index order is identical either way.
        wanted = set(class_list)
        indices = [i for i, label in enumerate(dataset.targets) if label in wanted]
        return torch.utils.data.Subset(dataset, indices)

    train_df = filter_by_class(full_train, classes)
    test_df = filter_by_class(full_test, classes)
    return train_df, test_df



def load_all_tasks(num_tasks, root='/root/cifar10'):
    """
    Collect the per-task data split for every task id in ``range(num_tasks)``.

    Args:
        num_tasks (int): total number of tasks
        root (str): dataset root path, forwarded to ``generate_task_classes``
    Returns:
        list: one ``(train_df, test_df)`` tuple per task, ordered by task id
    """
    per_task_splits = []
    for current_task in range(num_tasks):
        # Each call builds its own train/test subsets for this task id.
        split = generate_task_classes(current_task, num_tasks, root=root)
        per_task_splits.append(split)
    return per_task_splits
