import numpy as np
import torch
from dataloader.sampler import CategoriesSampler
# from torch.utils.data import Dataset
from torchvision import transforms
from augmentations.constrained_cropping import CustomMultiCropDataset, CustomMultiCropping
from torch.utils.data import DataLoader, Subset, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
from dataloader.autoaugment import CIFAR10Policy, Cutout, ImageNetPolicy


def set_up_datasets(args):
    """Attach per-dataset FSCIL settings and the dataset module to ``args``.

    Sets ``base_class``, ``num_classes``, ``way``, ``shot`` (copied from
    ``args.incremental_shot``), ``sessions`` and the ``fake_*`` settings used
    for pseudo-incremental episodes, then stores the imported dataset module
    as ``args.Dataset``.

    Args:
        args: namespace with at least ``dataset`` and ``incremental_shot``.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        ValueError: if ``args.dataset`` is not ``cifar100``, ``cub200`` or
            ``mini_imagenet``.
    """
    if args.dataset == 'cifar100':
        import dataloader.cifar100.cifar as Dataset
        args.base_class = 60
        args.num_classes = 100
        args.way = 5
        args.sessions = 9
        args.fake_base_class = 15
        args.fake_way = 5
    elif args.dataset == 'cub200':
        import dataloader.cub200.cub200 as Dataset
        args.base_class = 100
        args.num_classes = 200
        args.way = 10
        args.sessions = 11
        args.fake_base_class = 25
        args.fake_way = 10
    elif args.dataset == 'mini_imagenet':
        import dataloader.miniimagenet.miniimagenet as Dataset
        args.base_class = 60
        args.num_classes = 100
        args.way = 5
        args.sessions = 9
        args.fake_base_class = 15
        args.fake_way = 5
    else:
        # Without this, the final assignment would fail with a confusing
        # UnboundLocalError on ``Dataset``.
        raise ValueError(
            "Unsupported dataset {!r}; expected one of 'cifar100', 'cub200', "
            "'mini_imagenet'".format(args.dataset))

    # Settings shared by every supported benchmark.
    args.shot = args.incremental_shot
    args.fake_shot = 5
    args.Dataset = Dataset
    return args


def appendKBaseExemplars(dataset, args, nclass):
    """Append the first ``args.exemplars_count`` samples of each base class to ``dataset``.

    For every class ``0 .. nclass-1``, the first ``args.exemplars_count``
    samples (in storage order) are appended to ``dataset.data`` /
    ``dataset.targets``. Always taking the first K keeps the exemplar set
    identical across experiments. A class with fewer than K samples
    contributes all the samples it has.

    Two storage layouts are handled:
      * ``cifar100``: ``dataset.data_all`` / ``dataset.targets_all`` are NumPy
        arrays with the full training set. The selected rows are stacked onto
        ``dataset.data`` / ``dataset.targets``, which are reassigned.
      * otherwise: ``dataset.data2label`` maps each sample key to its label,
        and ``dataset.data`` / ``dataset.targets`` are lists extended in place.

    Args:
        dataset: dataset object to extend (mutated).
        args: namespace with ``dataset`` and ``exemplars_count``.
        nclass: number of leading class labels to draw exemplars from.

    Returns:
        ``(data, targets)``: the appended samples and labels. These are NumPy
        arrays for CIFAR-100 and lists otherwise.
    """
    k = args.exemplars_count

    if args.dataset == "cifar100":
        # Collect indices for all classes first, then stack once. This avoids
        # re-copying the growing array once per class.
        picked = [np.where(c == dataset.targets_all)[0][:k] for c in range(nclass)]
        idx = np.concatenate(picked) if picked else np.empty(0, dtype=np.intp)
        new_data = dataset.data_all[idx]
        new_targets = dataset.targets_all[idx]
        if picked:
            dataset.data = np.vstack((dataset.data, new_data))
            dataset.targets = np.hstack((dataset.targets, new_targets))
        return new_data, new_targets

    # Group sample keys by label. The dict's iteration order is kept, so the
    # "first K" choice is stable.
    label2data = {}
    for key, label in dataset.data2label.items():
        if label < nclass:
            label2data.setdefault(label, []).append(key)

    data_tmp, targets_tmp = [], []
    for c in range(nclass):
        # Slice (rather than index) so short classes behave like the CIFAR path.
        chosen = label2data.get(c, [])[:k]
        data_tmp.extend(chosen)
        targets_tmp.extend([c] * len(chosen))

    dataset.data.extend(data_tmp)
    dataset.targets.extend(targets_tmp)
    return data_tmp, targets_tmp


class TwoCropTransform:
    """Apply ``transform`` twice to one input, giving two independent views.

    Each call of the wrapped transform is separate, so a stochastic pipeline
    yields two differently augmented versions of the same sample.
    """

    def __init__(self, transform):
        # Augmentation pipeline that each view is drawn from.
        self.transform = transform

    def __call__(self, x):
        """Return ``[view_1, view_2]``, produced by two consecutive ``transform`` calls."""
        return [self.transform(x) for _ in range(2)]


class MultiCropTransform:
    """Apply ``transform`` ``n_views`` times to one input, giving independent views.

    Args:
        transform: augmentation callable applied to the input once per view.
        n_views: number of views to produce (default 2).
    """

    def __init__(self, transform, n_views=2):
        self.transform = transform
        self.n_views = n_views

    def __call__(self, x):
        """Return a list of ``n_views`` outputs, one per separate ``transform`` call."""
        return [self.transform(x) for _ in range(self.n_views)]


def get_supcon_dataloader(args):
    """Build supervised-contrastive (SupCon) loaders for the base session.

    The training set holds the first ``args.base_class`` classes. Each sample
    yields a list of two independently augmented views
    (``MultiCropTransform`` / ``TwoCropTransform``). The augmentations are
    either policy-based (``args.rand_aug_sup_con``) or a colour-jitter pipeline
    with random-resized-crop scale ``(args.min_crop_scale, 1)``.

    Args:
        args: namespace prepared by ``set_up_datasets`` (``Dataset``,
            ``base_class``) that also provides ``dataset``, ``dataroot``,
            ``rand_aug_sup_con``, ``min_crop_scale``, ``batch_size_sup_con``,
            ``drop_last_batch`` and ``test_batch_size``.

    Returns:
        ``(trainset, trainloader, testloader)``. Both loaders use 8 workers.

    Raises:
        ValueError: if ``args.dataset`` is not ``cifar100``, ``cub200`` or
            ``mini_imagenet``.
    """
    if args.dataset not in ('cifar100', 'cub200', 'mini_imagenet'):
        raise ValueError("Unsupported dataset for SupCon training: {!r}".format(args.dataset))

    class_index = np.arange(args.base_class)
    if args.dataset == 'cifar100':
        normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        if args.rand_aug_sup_con:
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(size=32, scale=(args.min_crop_scale, 1.)),
                transforms.RandomHorizontalFlip(),
                CIFAR10Policy(),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                Cutout(n_holes=1, length=16),
                normalize,
            ])
        else:
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(size=32, scale=(args.min_crop_scale, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                normalize,
            ])
        trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=False,
                                         index=class_index, base_sess=True,
                                         transform=MultiCropTransform(train_transform))
        # No transform is passed here. The test set presumably falls back to the
        # dataset class's own default transform (TODO confirm).
        testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False,
                                        index=class_index, base_sess=False)

    elif args.dataset == 'cub200':
        # NOTE(review): these are CIFAR-10 channel statistics, not the ImageNet
        # ones ([0.485, 0.456, 0.406] / [0.229, 0.224, 0.225]). They are kept
        # unchanged so existing results stay comparable.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        if args.rand_aug_sup_con:
            # NOTE(review): this branch ignores args.min_crop_scale (default crop scale).
            train_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                ImageNetPolicy(),
                transforms.ToTensor(),
                normalize
            ])
        else:
            train_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomResizedCrop(size=224, scale=(args.min_crop_scale, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                normalize,
            ])
        trainset = args.Dataset.CUB200(root=args.dataroot, train=True,
                                       index=class_index, base_sess=True,
                                       dino_transform=None, transform=MultiCropTransform(train_transform))
        testset = args.Dataset.CUB200(root=args.dataroot, train=False, index=class_index)

    else:  # mini_imagenet
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if args.rand_aug_sup_con:
            train_transform = transforms.Compose([
                transforms.RandAugment(num_ops=3, magnitude=11),
                transforms.RandomResizedCrop(size=84, scale=(args.min_crop_scale, 1.)),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(size=84, scale=(args.min_crop_scale, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                normalize,
            ])
        trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True,
                                             index=class_index, base_sess=True,
                                             transform=TwoCropTransform(train_transform))
        testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False, index=class_index)

    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_sup_con, shuffle=True,
                                              num_workers=8, pin_memory=True, drop_last=args.drop_last_batch)
    testloader = torch.utils.data.DataLoader(
        dataset=testset, batch_size=args.test_batch_size, shuffle=False, num_workers=8, pin_memory=True)

    return trainset, trainloader, testloader


def get_supcon_joint_dataloader(args, session):
    """Build a SupCon training loader over all data available at ``session``.

    The joint training set is assembled from three sources:

    1. The samples listed in ``data/index_list/<dataset>/session_<session+1>.txt``.
    2. The samples of every intermediate session ``1 .. session-1`` (files
       ``session_2.txt`` .. ``session_<session>.txt``). For CUB-200 and
       miniImageNet, only the first ``args.exemplars_count`` samples per class
       are kept when ``args.exemplars_count != args.shot``. CIFAR-100 always
       appends everything.
    3. The first ``args.exemplars_count`` samples of each of the
       ``args.base_class`` base classes (via ``appendKBaseExemplars``).

    Every training sample yields two augmented views (``MultiCropTransform(..., 2)``).

    Args:
        args: namespace prepared by ``set_up_datasets`` plus the SupCon options
            (``rand_aug_sup_con``, ``min_crop_scale``, ``prob_color_jitter``,
            ``exemplars_count``, ``batch_size_joint``, ``num_workers``,
            ``drop_last_batch``, ``dataroot``).
        session: index of the current session (0 is the base session).

    Returns:
        ``(trainset, trainloader)``. The loader shuffles; ``args.batch_size_joint == 0``
        means one batch containing the whole joint set.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    if args.dataset not in ('cifar100', 'cub200', 'mini_imagenet'):
        raise ValueError("Unsupported dataset for joint SupCon training: {!r}".format(args.dataset))

    txt_path = "data/index_list/" + args.dataset + "/session_" + str(session + 1) + '.txt'
    if args.dataset == 'cifar100':
        normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        if args.rand_aug_sup_con:
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(size=32, scale=(args.min_crop_scale, 1.)),
                transforms.RandomHorizontalFlip(),
                CIFAR10Policy(),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                Cutout(n_holes=1, length=16),
                normalize,
            ])
        else:
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(size=32, scale=(args.min_crop_scale, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=args.prob_color_jitter),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                normalize,
            ])
        with open(txt_path) as f:
            class_index = f.read().splitlines()
        dataset_class = args.Dataset.CIFAR100
        trainset = dataset_class(root=args.dataroot, train=True, download=False,
                                 index=class_index, base_sess=False, keep_all=True,
                                 transform=MultiCropTransform(train_transform, 2),
                                 base_aug_mag=0)

    elif args.dataset == 'cub200':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        if args.rand_aug_sup_con:
            train_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                ImageNetPolicy(),
                transforms.ToTensor(),
                normalize
            ])
        else:
            train_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.RandomResizedCrop(size=224, scale=(args.min_crop_scale, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=args.prob_color_jitter),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                normalize,
            ])
        dataset_class = args.Dataset.CUB200
        trainset = dataset_class(root=args.dataroot, train=True,
                                 index_path=txt_path, transform=MultiCropTransform(train_transform, 2))

    else:  # mini_imagenet
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if args.rand_aug_sup_con:
            # NOTE(review): get_supcon_dataloader applies RandAugment(num_ops=3,
            # magnitude=11) *before* the crop; here it comes after with default magnitude.
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(size=84, scale=(args.min_crop_scale, 1.)),
                transforms.RandAugment(num_ops=3),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            train_transform = transforms.Compose([
                transforms.RandomResizedCrop(size=84, scale=(args.min_crop_scale, 1.)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
                ], p=args.prob_color_jitter),
                transforms.RandomGrayscale(p=0.2),
                transforms.ToTensor(),
                normalize,
            ])
        dataset_class = args.Dataset.MiniImageNet
        trainset = dataset_class(root=args.dataroot, train=True,
                                 index_path=txt_path, transform=MultiCropTransform(train_transform, 2))

    nclass = args.base_class

    # Append data from each intermediate session (0 < inter_ix < session).
    for inter_ix in range(1, session):
        inter_path = "data/index_list/" + args.dataset + "/session_" + str(inter_ix + 1) + '.txt'
        if args.dataset == "cifar100":
            with open(inter_path) as f:
                inter_index = f.read().splitlines()
            inter_set = dataset_class(root=args.dataroot, train=True, download=False, index=inter_index,
                                      base_sess=False)
            # NOTE(review): unlike the other datasets, args.exemplars_count is not
            # applied here; every sample of the intermediate session is appended.
            trainset.data = np.vstack((trainset.data, inter_set.data))
            trainset.targets = np.hstack((trainset.targets, inter_set.targets))
        else:
            inter_set = dataset_class(root=args.dataroot, train=True, index_path=inter_path,
                                      base_sess=False)
            if args.exemplars_count != args.shot:
                # Exemplar control: keep only the first exemplars_count samples per class.
                inter_targets = np.array(inter_set.targets)
                for label in np.unique(inter_targets):
                    for j in np.where(inter_targets == label)[0][:args.exemplars_count]:
                        trainset.data.append(inter_set.data[j])
                        trainset.targets.append(inter_set.targets[j])
            else:
                trainset.data.extend(inter_set.data)
                trainset.targets.extend(inter_set.targets)

    # Append the base-class exemplars to the current dataset.
    appendKBaseExemplars(trainset, args, nclass)

    # batch_size_joint == 0 means a single batch holding the whole joint set.
    # TODO: the full-batch case is also shuffled; check whether that matters.
    batch_size = len(trainset) if args.batch_size_joint == 0 else args.batch_size_joint
    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True,
                                              num_workers=args.num_workers, pin_memory=True,
                                              drop_last=args.drop_last_batch)

    return trainset, trainloader


def get_transform(args):
    """Build the multi-crop transform and the per-crop secondary transform(s).

    ``crop_transform`` is a ``CustomMultiCropping``:
      * ``args.num_crops[0]`` crops of size ``args.size_crops[0]`` and
        ``args.num_crops[1]`` crops of size ``args.size_crops[1]``;
      * scale ranges come from ``args.min_scale_crops`` / ``args.max_scale_crops``;
      * small crops are conditioned on the key crop when
        ``args.constrained_cropping`` is set.

    ``secondary_transform`` takes one of two forms:
      * If ``args.auto_augment`` is empty: a single colour-jitter pipeline
        (flip, jitter, grayscale, ToTensor, normalize).
      * Otherwise: a list with one transform per crop, large crops first.
        Crops whose running index is in ``args.auto_augment`` get a
        ``RandomChoice`` between an AutoAugment pipeline and the colour-jitter
        pipeline. All other crops get the colour-jitter pipeline alone.

    Returns:
        ``(crop_transform, secondary_transform)``.

    Raises:
        ValueError: if ``args.dataset`` has no normalization statistics here, or
            ``args.size_crops`` does not have exactly two entries.
    """
    norm_stats = {
        'cifar100': ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
        'cub200': ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        'mini_imagenet': ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    }
    if args.dataset not in norm_stats:
        raise ValueError("No normalization statistics for dataset {!r}".format(args.dataset))
    # Raise instead of assert so the check survives `python -O`.
    if len(args.size_crops) != 2:
        raise ValueError("args.size_crops must have exactly 2 entries (large, small), "
                         "got {}".format(len(args.size_crops)))
    mean, std = norm_stats[args.dataset]
    normalize = transforms.Normalize(mean=mean, std=std)

    def color_jitter_pipeline():
        # A fresh Compose per call, so every crop owns its own transform objects.
        # Gaussian blur is intentionally not included.
        return transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            normalize])

    crop_transform = CustomMultiCropping(size_large=args.size_crops[0],
                                         scale_large=(args.min_scale_crops[0], args.max_scale_crops[0]),
                                         size_small=args.size_crops[1],
                                         scale_small=(args.min_scale_crops[1], args.max_scale_crops[1]),
                                         N_large=args.num_crops[0], N_small=args.num_crops[1],
                                         condition_small_crops_on_key=args.constrained_cropping)

    if len(args.auto_augment) == 0:
        print('No auto augment - Apply regular moco v2 as secondary transform')
        return crop_transform, color_jitter_pipeline()

    from utils.auto_augment.auto_augment import AutoAugment
    from utils.auto_augment.random_choice import RandomChoice
    print('Auto augment - Apply custom auto-augment strategy')
    auto_aug_crops = set(args.auto_augment)  # built once, not per crop
    secondary_transform = []
    counter = 0
    for i in range(len(args.size_crops)):
        for _ in range(args.num_crops[i]):
            if counter not in auto_aug_crops:
                print('Crop {} - Apply regular secondary transform'.format(counter))
                secondary_transform.append(color_jitter_pipeline())
            else:
                print('Crop {} - Apply auto-augment/regular secondary transform'.format(counter))
                auto_aug = transforms.Compose([
                    transforms.RandomHorizontalFlip(),
                    AutoAugment(),
                    transforms.ToTensor(),
                    normalize])
                secondary_transform.append(RandomChoice([auto_aug, color_jitter_pipeline()]))
            counter += 1
    return crop_transform, secondary_transform


def get_dataloader(args, session):
    """Return ``(trainset, trainloader, testloader)`` for ``session``.

    Session 0 is the base session and is served by ``get_base_dataloader``.
    Every later session is an incremental session served by ``get_new_dataloader``.
    """
    loaders = get_base_dataloader(args) if session == 0 else get_new_dataloader(args, session)
    trainset, trainloader, testloader = loaders
    return trainset, trainloader, testloader


def get_new_dataloader(args, session):
    """Build loaders for incremental session ``session`` (>= 1).

    Training data is the set of samples listed in
    ``data/index_list/<dataset>/session_<session+1>.txt``. Test data covers the
    classes returned by ``get_session_classes(args, session)``.

    If ``args.batch_size_new == 0``, the whole training set is served as one
    unshuffled batch. Otherwise batches of ``args.batch_size_new`` are shuffled.

    Returns:
        ``(trainset, trainloader, testloader)``.

    Raises:
        ValueError: if ``args.dataset`` is not ``cifar100``, ``cub200`` or
            ``mini_imagenet``.
    """
    if args.dataset not in ('cifar100', 'cub200', 'mini_imagenet'):
        raise ValueError("Unsupported dataset for incremental sessions: {!r}".format(args.dataset))

    txt_path = "data/index_list/" + args.dataset + "/session_" + str(session + 1) + '.txt'
    # Only CIFAR-100 consumes the parsed list, but the file is read for every
    # dataset so a missing index file is reported here.
    with open(txt_path) as f:
        class_index = f.read().splitlines()

    if args.dataset == 'cifar100':
        trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=False,
                                         index=class_index, base_sess=False)
    elif args.dataset == 'cub200':
        trainset = args.Dataset.CUB200(root=args.dataroot, train=True,
                                       index_path=txt_path)
    else:  # mini_imagenet
        trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True,
                                             index_path=txt_path)

    if args.batch_size_new == 0:
        batch_size, shuffle = len(trainset), False
    else:
        batch_size, shuffle = args.batch_size_new, True
    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=shuffle,
                                              num_workers=args.num_workers, pin_memory=True)

    # Test on all classes encountered so far.
    class_new = get_session_classes(args, session)

    if args.dataset == 'cifar100':
        testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False,
                                        index=class_new, base_sess=False)
    elif args.dataset == 'cub200':
        testset = args.Dataset.CUB200(root=args.dataroot, train=False,
                                      index=class_new)
    else:  # mini_imagenet
        testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False,
                                            index=class_new)

    testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=args.test_batch_size, shuffle=False,
                                             num_workers=args.num_workers, pin_memory=True)

    return trainset, trainloader, testloader


def get_base_dataloader(args):
    """Build train/test loaders for the base session (classes ``0 .. base_class-1``).

    The CIFAR-100 training split is created with ``download=True``. The train
    loader shuffles with batch size ``args.batch_size_base``. Both loaders use
    8 workers.

    Returns:
        ``(trainset, trainloader, testloader)``.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``cifar100``, ``cub200``,
            ``mini_imagenet``, ``imagenet100`` or ``imagenet1000``.
    """
    supported = ('cifar100', 'cub200', 'mini_imagenet', 'imagenet100', 'imagenet1000')
    if args.dataset not in supported:
        raise ValueError("Unsupported dataset for the base session: {!r}".format(args.dataset))

    class_index = np.arange(args.base_class)
    if args.dataset == 'cifar100':
        trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=True,
                                         index=class_index, base_sess=True)
        testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False,
                                        index=class_index, base_sess=True)
    elif args.dataset == 'cub200':
        trainset = args.Dataset.CUB200(root=args.dataroot, train=True,
                                       index=class_index, base_sess=True)
        testset = args.Dataset.CUB200(root=args.dataroot, train=False, index=class_index)
    elif args.dataset == 'mini_imagenet':
        trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True,
                                             index=class_index, base_sess=True)
        testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False, index=class_index)
    else:  # imagenet100 / imagenet1000
        trainset = args.Dataset.ImageNet(root=args.dataroot, train=True,
                                         index=class_index, base_sess=True)
        testset = args.Dataset.ImageNet(root=args.dataroot, train=False, index=class_index)

    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_base, shuffle=True,
                                              num_workers=8, pin_memory=True)
    testloader = torch.utils.data.DataLoader(
        dataset=testset, batch_size=args.test_batch_size, shuffle=False, num_workers=8, pin_memory=True)

    return trainset, trainloader, testloader


def get_base_dataloader_new(args):
    """Build base-session loaders whose training set uses multi-crop augmentation.

    For CIFAR-100, CUB-200 and miniImageNet, the training dataset receives the
    ``(crop_transform, secondary_transform)`` pair built by ``get_transform``.
    The ImageNet variants are built without those transforms.

    Returns:
        ``(trainset, trainloader, testloader)``. The train loader shuffles with
        ``args.batch_size_base``. Both loaders use 8 workers.

    Raises:
        ValueError: if ``args.dataset`` is not supported.
    """
    supported = ('cifar100', 'cub200', 'mini_imagenet', 'imagenet100', 'imagenet1000')
    if args.dataset not in supported:
        raise ValueError("Unsupported dataset for the base session: {!r}".format(args.dataset))

    class_index = np.arange(args.base_class)
    if args.dataset in ('imagenet100', 'imagenet1000'):
        # get_transform has no normalization statistics for ImageNet, and this
        # branch never used its output, so it is not called here.
        trainset = args.Dataset.ImageNet(root=args.dataroot, train=True,
                                         index=class_index, base_sess=True)
        testset = args.Dataset.ImageNet(root=args.dataroot, train=False, index=class_index)
    else:
        crop_transform, secondary_transform = get_transform(args)
        if args.dataset == 'cifar100':
            trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=True, index=class_index,
                                             base_sess=True, crop_transform=crop_transform,
                                             secondary_transform=secondary_transform)
            testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False,
                                            index=class_index, base_sess=True)
        elif args.dataset == 'cub200':
            trainset = args.Dataset.CUB200(root=args.dataroot, train=True, index=class_index, base_sess=True,
                                           crop_transform=crop_transform, secondary_transform=secondary_transform)
            testset = args.Dataset.CUB200(root=args.dataroot, train=False, index=class_index)
        else:  # mini_imagenet
            trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True, index=class_index,
                                                 base_sess=True, crop_transform=crop_transform,
                                                 secondary_transform=secondary_transform)
            testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False, index=class_index)

    trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_base, shuffle=True,
                                              num_workers=8, pin_memory=True)
    testloader = torch.utils.data.DataLoader(
        dataset=testset, batch_size=args.test_batch_size, shuffle=False, num_workers=8, pin_memory=True)

    return trainset, trainloader, testloader


def get_fake_incremental_data(trainset, class_index, args, epoch=0):
    """Split base-session data into a pseudo (fake) incremental task.

    The classes in ``class_index`` are split into two groups:
      * the first ``len(class_index) // 4`` classes are fake "pre-training"
        classes;
      * the remaining classes are "post-training" classes.

    Separately, for every class in ``class_index``, up to ``num_shot`` support
    samples and up to ``num_query`` query samples are drawn from a random
    permutation of that class's indices. The two draws never overlap.

    Episode settings:
      * ``num_shot`` is always 5.
      * CIFAR-100 / miniImageNet: ``num_way = 5`` and ``num_query = 50``.
      * Other datasets: ``num_way = 10`` and ``num_query = 5``.

    Args:
        trainset: dataset exposing a ``targets`` sequence aligned with its indices.
        class_index: sequence of class labels to use.
        args: namespace with ``dataset``, ``batch_size_base`` and ``num_workers``.
        epoch: seed for the support/query split. This reseeds NumPy's *global*
            RNG via ``np.random.seed``.

    Returns:
        ``(pretrain_loader, postrain_loader, support_loader, query_loader)``:
          * the pre/post loaders shuffle, with batch size ``args.batch_size_base``;
          * the support/query loaders are unshuffled, with batch sizes
            ``num_way * num_shot`` and ``num_way * num_query``.
    """
    targets = np.array(trainset.targets)  # built once, reused for every class

    num_pretrain = len(class_index) // 4
    pretrain_classes = class_index[:num_pretrain]
    postrain_classes = class_index[num_pretrain:]

    pretrain_set = Subset(trainset, np.where(np.isin(targets, pretrain_classes))[0])
    pretrain_loader = DataLoader(
        dataset=pretrain_set, batch_size=args.batch_size_base, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )

    postrain_set = Subset(trainset, np.where(np.isin(targets, postrain_classes))[0])
    postrain_loader = DataLoader(
        dataset=postrain_set, batch_size=args.batch_size_base, shuffle=True,
        num_workers=args.num_workers, pin_memory=True
    )

    five_way_benchmark = args.dataset in ['cifar100', 'mini_imagenet']
    num_way = 5 if five_way_benchmark else 10
    num_shot = 5
    num_query = 50 if five_way_benchmark else 5

    np.random.seed(epoch)

    support_indices = []
    query_indices = []
    for c in class_index:  # every class, not only the post-training ones
        c_indices = np.where(targets == c)[0]
        np.random.shuffle(c_indices)
        # Disjoint slices: the first num_shot are support, the next num_query are query.
        support_indices.extend(c_indices[:num_shot])
        query_indices.extend(c_indices[num_shot:num_shot + num_query])

    support_loader = DataLoader(
        dataset=Subset(trainset, support_indices), batch_size=num_way * num_shot, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    query_loader = DataLoader(
        dataset=Subset(trainset, query_indices), batch_size=num_way * num_query, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )

    return pretrain_loader, postrain_loader, support_loader, query_loader


def get_meta_data(trainset, class_index, args, seed):
    """Draw one support/query split covering every class in ``class_index``.

    Classes are visited in a random order. For each class, up to
    ``args.pseudo_shot`` support samples and up to ``args.pseudo_shot`` query
    samples are taken from a random permutation of that class's indices, so
    support and query never overlap.

    Both loaders are unshuffled, with batch size
    ``args.pseudo_way * args.pseudo_shot``. When every class has enough
    samples, each batch therefore groups ``pseudo_way`` consecutive classes.

    Randomness:
      * The class order comes from a private ``random.Random(seed)``. It is
        reproducible for a given ``seed`` and does not modify the caller's
        ``class_index``.
      * The per-class permutations reseed NumPy's *global* RNG with ``seed``.

    Args:
        trainset: dataset exposing a ``targets`` sequence aligned with its indices.
        class_index: sequence of class labels (left unmodified).
        args: namespace with ``pseudo_way``, ``pseudo_shot`` and ``num_workers``.
        seed: integer seed (or ``None``).

    Returns:
        ``(support_loader, query_loader)``.
    """
    np.random.seed(seed)
    targets = np.array(trainset.targets)  # built once, reused for every class

    num_way = args.pseudo_way
    num_shot = num_query = args.pseudo_shot

    # Shuffle a copy with a seeded private RNG: reproducible, and the caller's
    # sequence is not mutated.
    classes = list(class_index)
    random.Random(None if seed is None else int(seed)).shuffle(classes)

    support_indices = []
    query_indices = []
    for c in classes:
        c_indices = np.where(targets == c)[0]
        np.random.shuffle(c_indices)
        support_indices.extend(c_indices[:num_shot])
        query_indices.extend(c_indices[num_shot:num_shot + num_query])

    support_loader = DataLoader(
        dataset=Subset(trainset, support_indices), batch_size=num_way * num_shot, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )
    query_loader = DataLoader(
        dataset=Subset(trainset, query_indices), batch_size=num_way * num_query, shuffle=False,
        num_workers=args.num_workers, pin_memory=True
    )

    return support_loader, query_loader


import random


def mixup_samples(data1, data2, num_samples, lam):
    """Blend randomly paired samples from two pools with a fixed mixing weight.

    Args:
        data1: tensor whose first dimension indexes the samples of the first pool.
        data2: tensor whose first dimension indexes the samples of the second pool.
        num_samples: how many mixed samples to produce.
        lam: weight given to the ``data1`` sample; ``data2`` gets ``1 - lam``.

    Returns:
        Tensor of ``num_samples`` mixed samples. Rows are drawn with replacement
        from each pool using torch's global RNG.
    """
    # Both draws come from torch's RNG, first pool first, then second pool.
    pick_a = torch.randint(0, len(data1), (num_samples,))
    pick_b = torch.randint(0, len(data2), (num_samples,))

    first = data1[pick_a]
    second = data2[pick_b]
    return lam * first + (1 - lam) * second


def generate_class_combinations(class_index, num_new_classes):
    """Randomly pick ``num_new_classes`` distinct unordered pairs of classes.

    Every pair ``(class_index[i], class_index[j])`` with ``i < j`` is a
    candidate. The selection uses the module-level ``random`` RNG, so seeding
    ``random`` beforehand makes the result reproducible.

    Raises:
        ValueError: (from ``random.sample``) if more pairs are requested than exist.
    """
    count = len(class_index)
    pairs = []
    for i in range(count):
        first = class_index[i]
        for j in range(i + 1, count):
            pairs.append((first, class_index[j]))
    return random.sample(pairs, num_new_classes)


def get_meta_data_new(trainset, class_index, args, seed, selected_combinations=None):
    """Build support/query loaders for synthetic classes created by mixup.

    Each synthetic class is a pair ``(class1, class2)`` of existing classes.
    Its samples are ``lam * x1 + (1 - lam) * x2``, where ``x1`` and ``x2`` are
    drawn (with replacement) from the two source classes, and one ``lam`` in
    ``[0.4, 0.6)`` is shared by all samples of that synthetic class.

    Args:
        trainset: dataset exposing ``targets``; ``trainset[i][0]`` must be a
            tensor, and all such tensors must have the same shape so they can be
            stacked.
        class_index: class ids eligible to be paired.
        args: must provide ``num_classes``, ``base_class``, ``dataset`` and
            ``num_workers``.
        seed: seeds ``numpy`` and ``random`` globally. It also seeds torch's CPU
            RNG inside a forked RNG scope, so the result is reproducible.
        selected_combinations: optional precomputed list of class pairs. When
            ``None``, ``args.num_classes - args.base_class`` pairs are drawn.

    Returns:
        ``(support_loader, query_loader, selected_combinations)``. The
        synthetic class at position ``idx`` is labelled ``args.base_class + idx``.
        Each synthetic class contributes 5 support and 5 query samples.
    """
    np.random.seed(seed)
    random.seed(seed)

    # Check if we need to generate new class combinations
    if selected_combinations is None:
        selected_combinations = generate_class_combinations(class_index, args.num_classes - args.base_class)

    # NOTE(review): way/shot are hard-coded here, whereas the fake-incremental
    # loader above reads args.pseudo_way / args.pseudo_shot — confirm intended.
    num_way = 5 if args.dataset in ['cifar100', 'mini_imagenet'] else 10
    num_shot, num_query = 5, 5
    per_class = num_shot + num_query

    # Loop-invariant: convert the label list once instead of once per pair.
    targets = np.array(trainset.targets)

    new_class_data = []
    new_class_labels = []
    # mixup_samples draws its indices from torch's RNG, which the np/random
    # seeding above does not cover. Forking the CPU torch RNG and seeding it
    # makes the output depend only on ``seed``. It also leaves the caller's
    # torch RNG state untouched once the block exits.
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(seed)
        for idx, (class1, class2) in enumerate(selected_combinations):
            idx1 = np.where(targets == class1)[0]
            idx2 = np.where(targets == class2)[0]

            data1 = torch.stack([trainset[i][0] for i in idx1])
            data2 = torch.stack([trainset[i][0] for i in idx2])
            lam = np.random.uniform(0.4, 0.6)  # one mixing weight per synthetic class
            new_class_data.append(mixup_samples(data1, data2, per_class, lam))
            new_class_labels.extend([args.base_class + idx] * per_class)

    new_class_data = torch.cat(new_class_data)
    new_class_labels = torch.tensor(new_class_labels)

    # Samples are laid out in contiguous blocks of ``per_class`` per synthetic
    # class. The first ``num_shot`` of each block are support; the rest are query.
    offsets = np.arange(len(selected_combinations))[:, None] * per_class
    support_indices = (offsets + np.arange(num_shot)).ravel()
    query_indices = (offsets + np.arange(num_shot, per_class)).ravel()

    support_set = TensorDataset(new_class_data[support_indices], new_class_labels[support_indices])
    query_set = TensorDataset(new_class_data[query_indices], new_class_labels[query_indices])
    support_loader = DataLoader(support_set, batch_size=num_way * num_shot, shuffle=False, num_workers=args.num_workers,
                                pin_memory=True)
    query_loader = DataLoader(query_set, batch_size=num_way * num_query, shuffle=False, num_workers=args.num_workers,
                              pin_memory=True)

    return support_loader, query_loader, selected_combinations


# def get_fake_incremental_data(trainset, class_index, args):
#     """
#     Creates a fake incremental task with num_way classes and num_shot + num_query examples per class.
#     """
#     # Calculate the number of classes for pretraining and incremental learning
#     num_classes = len(class_index)
#     half_classes = num_classes // 2
#     pretrain_classes = class_index
#     incremental_classes = class_index
#
#     # Create fake pretraining subset and loader
#     pretrain_indices = np.where(np.isin(np.array(trainset.targets), pretrain_classes))[0]
#     pretrain_set = Subset(trainset, pretrain_indices)
#     pretrain_loader = DataLoader(
#         dataset=pretrain_set, batch_size=args.batch_size_base, shuffle=True,
#         num_workers=args.num_workers, pin_memory=True
#     )
#
#     # Determine the way, shot, and query for the incremental learning task based on the dataset
#     num_way = 5 if args.dataset in ['cifar100', 'mini_imagenet'] else 10
#     num_shot, num_query = 5, 5  # Defaults for both CIFAR100/MiniImageNet and CUB200
#
#     # Ensure the same random state for reproducibility if needed
#     some_seed = 2024
#     np.random.seed(some_seed)
#
#     # Initialize lists to hold the indices for support and query sets
#     support_indices = []
#     query_indices = []
#
#     # Instead of randomly selecting a few classes, iterate over all classes in incremental_classes
#     for c in incremental_classes:
#         # Get all indices for class c
#         c_indices = np.where(np.array(trainset.targets) == c)[0]
#         np.random.shuffle(c_indices)
#
#         # Split indices for support and query sets
#         support_c_indices = c_indices[:num_shot]
#         query_c_indices = c_indices[num_shot:num_shot + num_query]
#
#         # Add indices to the respective lists
#         support_indices.extend(support_c_indices)
#         query_indices.extend(query_c_indices)
#     # Create support and query subsets and loaders
#     support_set = Subset(trainset, support_indices)
#     query_set = Subset(trainset, query_indices)
#
#     support_loader = DataLoader(
#         dataset=support_set, batch_size=num_way * num_shot, shuffle=False,
#         num_workers=args.num_workers, pin_memory=True
#     )
#     query_loader = DataLoader(
#         dataset=query_set, batch_size=num_way * num_query, shuffle=False,
#         num_workers=args.num_workers, pin_memory=True
#     )
#
#     return pretrain_loader, support_loader, query_loader


# def get_new_dataloader(args, session):
#     crop_transform, secondary_transform = get_transform(args)
#     txt_path = "data/index_list/" + args.dataset + "/session_" + str(session + 1) + '.txt'
#
#     if args.dataset == 'cifar100':
#         class_index = open(txt_path).read().splitlines()
#         trainset = args.Dataset.CIFAR100(root=args.dataroot, train=True, download=False,
#                                          index=class_index, base_sess=False)
#     if args.dataset == 'cub200':
#         trainset = args.Dataset.CUB200(root=args.dataroot, train=True,
#                                        index_path=txt_path)
#     if args.dataset == 'mini_imagenet':
#         trainset = args.Dataset.MiniImageNet(root=args.dataroot, train=True,
#                                              index_path=txt_path)
#
#     if args.batch_size_new == 0:
#         batch_size_new = trainset.__len__()
#         trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size_new, shuffle=False,
#                                                   num_workers=args.num_workers, pin_memory=True)
#     else:
#         trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=args.batch_size_new, shuffle=True,
#                                                   num_workers=args.num_workers, pin_memory=True)
#
#     # test on all encountered classes
#     class_new = get_session_classes(args, session)
#
#     if args.dataset == 'cifar100':
#         testset = args.Dataset.CIFAR100(root=args.dataroot, train=False, download=False,
#                                         index=class_new, base_sess=False)
#     if args.dataset == 'cub200':
#         testset = args.Dataset.CUB200(root=args.dataroot, train=False,
#                                       index=class_new)
#     if args.dataset == 'mini_imagenet':
#         testset = args.Dataset.MiniImageNet(root=args.dataroot, train=False,
#                                             index=class_new)
#
#     testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=args.test_batch_size, shuffle=False,
#                                              num_workers=args.num_workers, pin_memory=True)
#
#     return trainset, trainloader, testloader


def get_session_classes(args, session):
    """Return ``np.arange(args.base_class + session * args.way)``.

    These are the class ids ``0 .. N-1``, where ``N`` counts the
    ``args.base_class`` base classes plus ``args.way`` classes for each of the
    ``session`` subsequent sessions.
    """
    seen = args.base_class + args.way * session
    return np.arange(seen)
