import os
from dataset.augmentations import CifarImageTransforms
import torchvision.datasets as datasets
from dataset.augmentations import TinyImgTransform, ImageNetRTransform, DomainNetTransform
from dataset.tinyimg import TinyImagenet
from dataset.imagenet_r import ImageNetR
from dataset.domainnet import DomainNet
import torch.utils.data as data
import threading


def create_loader(train_dataset, test_dataset, train_class_indices, test_class_indices, classes_subset, batch_size, result_dict, index):
    """Build train/test DataLoaders restricted to the classes in ``classes_subset``.

    For every class in ``classes_subset`` (in order), its sample indices are read
    from ``train_class_indices`` / ``test_class_indices`` (class -> list of
    indices). Indices that are not smaller than the corresponding dataset's
    length are dropped.

    Nothing is returned. Instead, the pair ``(train_loader, test_loader)`` is
    stored in ``result_dict[index]``. The train loader shuffles; the test loader
    does not.
    """
    def _gather(dataset, per_class_indices):
        # Concatenate the per-class index lists, keeping only in-range indices.
        picked = []
        for cls in classes_subset:
            for sample_idx in per_class_indices[cls]:
                if sample_idx < len(dataset):
                    picked.append(sample_idx)
        return picked

    train_subset = data.Subset(train_dataset, _gather(train_dataset, train_class_indices))
    test_subset = data.Subset(test_dataset, _gather(test_dataset, test_class_indices))

    result_dict[index] = (
        data.DataLoader(train_subset, batch_size=batch_size, shuffle=True),
        data.DataLoader(test_subset, batch_size=batch_size, shuffle=False),
    )


def _indices_by_class(targets, total_classes):
    """Map each class id in ``range(total_classes)`` to the ascending indices whose target equals it."""
    buckets = {c: [] for c in range(total_classes)}
    for idx, label in enumerate(targets):
        # int() normalizes numpy scalars / 0-d tensors so they hash like Python
        # ints. Assumes targets are integer class ids — TODO confirm for custom datasets.
        bucket = buckets.get(int(label))
        if bucket is not None:  # labels outside range(total_classes) are ignored
            bucket.append(idx)
    return buckets


def create_sequential_dataloaders(train_dataset, test_dataset, n_classes: list, total_classes: int, batch_size: int):
    """Given an entire dataset, create len(n_classes) sequential tasks with n_classes[i] classes in task i.

    Task ``k`` covers the next ``n_classes[k]`` consecutive class ids, starting
    at 0. Task 0 gets classes ``0 .. n_classes[0]-1``, task 1 the following
    ``n_classes[1]`` ids, and so on.

    Args:
        train_dataset: training dataset exposing an indexable ``targets`` sequence.
        test_dataset: test dataset exposing an indexable ``targets`` sequence.
        n_classes: number of classes per task (any sequence of ints).
        total_classes: number of class ids considered, ``range(total_classes)``.
        batch_size: batch size for every DataLoader.

    Returns:
        ``(train_dataloaders, test_dataloaders)``: two lists with one loader per task.

    Raises:
        ValueError: if the tasks request more than ``total_classes`` classes.
    """
    n_classes = list(n_classes)
    requested = sum(n_classes)
    if requested > total_classes:
        raise ValueError(
            f"Tasks request {requested} classes but only {total_classes} are available"
        )

    # One pass over each target list instead of one pass per class.
    train_class_indices = _indices_by_class(train_dataset.targets, total_classes)
    test_class_indices = _indices_by_class(test_dataset.targets, total_classes)

    train_dataloaders = []
    test_dataloaders = []
    result_dict = {}
    start = 0
    for k, n_task in enumerate(n_classes):
        classes_subset = list(range(start, start + n_task))
        start += n_task

        # Called directly rather than in a worker thread. The work is pure
        # Python (no GIL release), so threads gave no speedup. A thread would
        # also swallow any exception raised by create_loader.
        create_loader(train_dataset, test_dataset, train_class_indices, test_class_indices,
                      classes_subset, batch_size, result_dict, k)
        train_loader, test_loader = result_dict[k]
        train_dataloaders.append(train_loader)
        test_dataloaders.append(test_loader)

    return train_dataloaders, test_dataloaders


def load_dataset(args):
    """Create training and test datasets based on dataset name.

    Reads ``args.dataset``, ``args.data_root``, ``args.n_classes_per_task`` and
    ``args.image_size``. Data lives under
    ``<data_root>/<dataset>_<n_classes_per_task>``; that directory is created if
    missing. Both splits are constructed with ``download=True``.

    Returns:
        ``(trainset, testset)``. The train split uses the matching transform's
        ``train_transform``; the test split uses its ``test_transform``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported name. This is raised
            before any directory is created or transform is built.
    """
    # name -> (dataset class, transform class). Only the selected transform is instantiated.
    registry = {
        "cifar10": (datasets.CIFAR10, CifarImageTransforms),
        "cifar100": (datasets.CIFAR100, CifarImageTransforms),
        "tinyimg": (TinyImagenet, TinyImgTransform),
        "imagenet-r": (ImageNetR, ImageNetRTransform),
        "domainnet": (DomainNet, DomainNetTransform),
    }

    try:
        dataset_cls, transform_cls = registry[args.dataset]
    except KeyError:
        raise ValueError(f"Not implemented for dataset: {args.dataset}") from None

    root = os.path.join(args.data_root, f"{args.dataset}_{args.n_classes_per_task}")
    os.makedirs(root, exist_ok=True)

    transform = transform_cls(args.image_size)
    trainset = dataset_cls(root=root, train=True, download=True, transform=transform.train_transform)
    testset = dataset_cls(root=root, train=False, download=True, transform=transform.test_transform)

    return trainset, testset