import sys
import random
import numpy as np
import torch
from torchvision import transforms
from PIL import Image
import itertools

# Prevent python from saving out .pyc files
sys.dont_write_bytecode = True
# Logging utility
from util import log

# Dimensionality of multiple-choice output
y_dim = 4
# Sequence length
seq_len = 7
# Asymmetric img list
asymmetric_list = range(49)


def _split_train_test(values):
    """Randomly split `values` into a training half and a held-out remainder.

    The training part is `len(values) // 2` items drawn without replacement;
    the test part is every remaining item, kept in the order of `values`.

    NOTE(review): the draw uses the global `random` state at import time, so
    the split changes between runs/processes unless the seed is fixed before
    this module is imported.
    """
    train_part = random.sample(values, len(values) // 2)
    test_part = [value for value in values if value not in train_part]
    return train_part, test_part


# Translations: every (dx, dy) pair on a 7 x 7 grid of offsets from -9 to 9 in steps of 3
_translate_offsets = range(-9, 10, 3)
translate_list = list(itertools.product(_translate_offsets, repeat=2))
translate_train_list, translate_test_list = _split_train_test(translate_list)

# Rotations: multiples of 15 degrees from 0 to 345
angle_list = list(range(0, 360, 15))
angle_train_list, angle_test_list = _split_train_test(angle_list)

# Reflections: only modes 0 and 1 (plain flips) are enabled; the full set
# would be [0, 1, 2, 3], where 2 and 3 are the flip-plus-rotation modes of reflect()
reflection_list = [0, 1]
reflection_train_list, reflection_test_list = _split_train_test(reflection_list)

# Shears: every [shear_x, shear_y] pair on a 9 x 9 grid from -60 to 60 in steps of 15
_shear_angles = range(-60, 61, 15)
shear_list = [list(pair) for pair in itertools.product(_shear_angles, repeat=2)]
shear_train_list, shear_test_list = _split_train_test(shear_list)

# Scale factors
scale_list = [1, 0.75, 0.5, 1.25]
scale_train_list, scale_test_list = _split_train_test(scale_list)

# Black/white inversion: candidate split positions and the four (axis, half) modes
pixel_list = range(8, 25)
black_white_list = list(itertools.product((0, 1), repeat=2))

# Fisheye distortion: magnitudes and every (x, y) centre on an 11 x 11 grid
fisheye_magnitude_list = [0.1, 0.2, 0.3, 0.4, 0.5]
_fisheye_offsets = (-0.5, -0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4, 0.5)
fisheye_center_list = list(itertools.product(_fisheye_offsets, repeat=2))

# Frequencies and amplitudes for the "hw" transformation
hw_freq_list = list(range(10, 31, 5))
hw_amplitude_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

# Every ordering of four positions except the identity ordering
swap_order_list = [order for order in itertools.permutations(range(4)) if order != (0, 1, 2, 3)]


def reflect(img, reflection):
    """Reflect an image tensor according to one of four reflection modes.

    A 2-D tensor is treated as a single (H, W) image. Any other tensor is
    treated as having one leading dimension before the two spatial ones
    (e.g. (C, H, W)), so the flipped dimension is shifted by one.

    Modes:
        0 -- flip along the first spatial dimension.
        1 -- flip along the second spatial dimension.
        2 -- flip along the first spatial dimension, then rotate with
             torchvision's ``affine`` using ``angle=90``.
        anything else -- same flip, then ``affine`` with ``angle=270``.

    Args:
        img: image tensor, (H, W) or (C, H, W). The image size and channel
            count are taken from the tensor itself rather than assumed to
            be 32 x 32 / 3 channels.
        reflection: reflection mode as described above.

    Returns:
        A new tensor with the same shape as ``img``.
    """
    is_single_plane = len(img.shape) == 2
    # Index of the first spatial dimension: 0 for (H, W), 1 for (C, H, W).
    first_spatial_dim = 0 if is_single_plane else 1

    if reflection in [0, 1]:
        return torch.flip(img, dims=[reflection + first_spatial_dim])

    # Modes 2 and 3 differ only in the rotation applied after the flip.
    angle = 90 if reflection == 2 else 270
    img_flip = torch.flip(img, dims=[first_spatial_dim])
    if is_single_plane:
        # affine() is given a leading channel dimension, so add a dummy one.
        img_flip = img_flip.unsqueeze(0)
    # fill=1 paints any area not covered by the transformed image with 1.
    img_rotated = transforms.functional.affine(img=img_flip, angle=angle, translate=(0, 0),
                                               scale=1, shear=0, fill=1)
    if is_single_plane:
        img_rotated = img_rotated.squeeze(0)
    return img_rotated


def build_seq(args, shapes, seq_len, train=True):
    all_seq = []
    all_target = []
    if args.transformation_method == 'omniglot-translation':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            translate = random.choice(translate_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                         angle=0, translate=translate, scale=1, shear=0, fill=1))
            img_list.append(all_imgs[j].view(1, 32, 32))
            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                                 angle=0, translate=translate, scale=1, shear=0,
                                                                 fill=1))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                     angle=0, translate=translate, scale=1,
                                                                     shear=0, fill=1))
                    else:
                        translate_false = random.choice([item for item in translate_list if item != translate])
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                     angle=0, translate=translate_false, scale=1,
                                                                     shear=0, fill=1))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-rotation':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            angle = random.choice(angle_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                         angle=angle, translate=(0, 0), scale=1, shear=0, fill=1))
            img_list.append(all_imgs[j].view(1, 32, 32))
            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                                 angle=angle, translate=(0, 0), scale=1, shear=0,
                                                                 fill=1))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                     angle=angle, translate=(0, 0), scale=1,
                                                                     shear=0, fill=1))
                    else:
                        angle_false = random.choice([item for item in angle_list if item != angle])
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                     angle=angle_false, translate=(0, 0), scale=1,
                                                                     shear=0, fill=1))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-shear':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            shear = random.choice(shear_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                         angle=0, translate=(0, 0), scale=1, shear=shear, fill=1))
            img_list.append(all_imgs[j].view(1, 32, 32))
            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                                 angle=0, translate=(0, 0), scale=1, shear=shear,
                                                                 fill=1))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=1,
                                                                     shear=shear, fill=1))
                    else:
                        shear_false = random.choice([item for item in shear_list if item != shear])
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=1,
                                                                     shear=shear_false, fill=1))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-scale':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            scale = random.choice(scale_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                         angle=0, translate=(0, 0), scale=scale, shear=0, fill=1))
            img_list.append(all_imgs[j].view(1, 32, 32))
            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                                 angle=0, translate=(0, 0), scale=scale, shear=0,
                                                                 fill=1))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=scale,
                                                                     shear=0, fill=1))
                    else:
                        scale_false = random.choice([item for item in scale_list if item != scale])
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                     angle=scale_false, translate=(0, 0), scale=1,
                                                                     shear=0, fill=1))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-reflection':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            reflection = random.choice(reflection_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(reflect(all_imgs[i], reflection=reflection).view(1, 32, 32))
            img_list.append(all_imgs[j].view(1, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(reflect(all_imgs[i], reflection=reflection).view(1, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(reflect(all_imgs[j], reflection=reflection).view(1, 32, 32))
                    else:
                        reflection_false = random.choice([item for item in reflection_list if item != reflection])
                        img_list.append(reflect(all_imgs[j], reflection=reflection_false).view(1, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-all':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            transformation_list = [item.strip() for item in args.transformation_mixture.split(',')]
            # transformation_list = ['translation', 'rotation', 'shear', 'scale', 'reflection']
            transformation = random.choice(transformation_list)

            img_list = []
            if transformation == 'translation':
                if args.generalization_type == 'object':
                    translate = random.choice(translate_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        translate = random.choice(translate_train_list)
                    else:
                        translate = random.choice(translate_test_list)

                img_list.append(all_imgs[i].view(1, 32, 32))
                img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                             angle=0, translate=translate, scale=1, shear=0, fill=1))
                img_list.append(all_imgs[j].view(1, 32, 32))
                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                                     angle=0, translate=translate, scale=1, shear=0,
                                                                     fill=1))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                         angle=0, translate=translate, scale=1,
                                                                         shear=0, fill=1))
                        else:
                            translate_false = random.choice([item for item in translate_list if item != translate])
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                         angle=0, translate=translate_false, scale=1,
                                                                         shear=0, fill=1))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(1, 32, 32))
            elif transformation == 'rotation':
                if args.generalization_type == 'object':
                    angle = random.choice(angle_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        angle = random.choice(angle_train_list)
                    else:
                        angle = random.choice(angle_test_list)

                img_list.append(all_imgs[i].view(1, 32, 32))
                img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                             angle=angle, translate=(0, 0), scale=1, shear=0, fill=1))
                img_list.append(all_imgs[j].view(1, 32, 32))
                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                                     angle=angle, translate=(0, 0), scale=1, shear=0,
                                                                     fill=1))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                         angle=angle, translate=(0, 0), scale=1,
                                                                         shear=0, fill=1))
                        else:
                            angle_false = random.choice([item for item in angle_list if item != angle])
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                         angle=angle_false, translate=(0, 0), scale=1,
                                                                         shear=0, fill=1))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(1, 32, 32))
            elif transformation == 'shear':
                if args.generalization_type == 'object':
                    shear = random.choice(shear_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        shear = random.choice(shear_train_list)
                    else:
                        shear = random.choice(shear_test_list)

                img_list.append(all_imgs[i].view(1, 32, 32))
                img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                             angle=0, translate=(0, 0), scale=1, shear=shear, fill=1))
                img_list.append(all_imgs[j].view(1, 32, 32))
                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=1, shear=shear,
                                                                     fill=1))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                         angle=0, translate=(0, 0), scale=1,
                                                                         shear=shear, fill=1))
                        else:
                            shear_false = random.choice([item for item in shear_list if item != shear])
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                         angle=0, translate=(0, 0), scale=1,
                                                                         shear=shear_false, fill=1))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(1, 32, 32))
            elif transformation == 'scale':
                if args.generalization_type == 'object':
                    scale = random.choice(scale_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        scale = random.choice(scale_train_list)
                    else:
                        scale = random.choice(scale_test_list)

                img_list.append(all_imgs[i].view(1, 32, 32))
                img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                             angle=0, translate=(0, 0), scale=scale, shear=0, fill=1))
                img_list.append(all_imgs[j].view(1, 32, 32))
                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(transforms.functional.affine(img=all_imgs[i].view(1, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=scale, shear=0,
                                                                     fill=1))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                         angle=0, translate=(0, 0), scale=scale,
                                                                         shear=0, fill=1))
                        else:
                            scale_false = random.choice([item for item in scale_list if item != scale])
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(1, 32, 32),
                                                                         angle=scale_false, translate=(0, 0), scale=1,
                                                                         shear=0, fill=1))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(1, 32, 32))
            elif transformation == 'reflection':
                if args.generalization_type == 'object':
                    reflection = random.choice(reflection_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        reflection = random.choice(reflection_train_list)
                    else:
                        reflection = random.choice(reflection_test_list)

                img_list = []
                img_list.append(all_imgs[i].view(1, 32, 32))
                img_list.append(reflect(all_imgs[i], reflection=reflection).view(1, 32, 32))
                img_list.append(all_imgs[j].view(1, 32, 32))

                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(reflect(all_imgs[i], reflection=reflection).view(1, 32, 32))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(reflect(all_imgs[j], reflection=reflection).view(1, 32, 32))
                        else:
                            reflection_false = random.choice([item for item in reflection_list if item != reflection])
                            img_list.append(reflect(all_imgs[j], reflection=reflection_false).view(1, 32, 32))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-black-white':
        def black_white(img, split_pixel, p):
            if p == (0, 0):
                img_1 = img[:split_pixel, :]
                img_2 = img[split_pixel:, :]
                img_bw = np.concatenate((1 - img_1, img_2), axis=0)
            elif p == (0, 1):
                img_1 = img[:split_pixel, :]
                img_2 = img[split_pixel:, :]
                img_bw = np.concatenate((img_1, 1 - img_2), axis=0)
            elif p == (1, 0):
                img_1 = img[:, :split_pixel]
                img_2 = img[:, split_pixel:]
                img_bw = np.concatenate((1 - img_1, img_2), axis=1)
            else:
                img_1 = img[:, :split_pixel]
                img_2 = img[:, split_pixel:]
                img_bw = np.concatenate((img_1, 1 - img_2), axis=1)
            return torch.Tensor(img_bw)

        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            split_pixel = random.choice(pixel_list)
            bw = random.choice(black_white_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(black_white(all_imgs[i], split_pixel=split_pixel, p=bw).view(1, 32, 32))
            img_list.append(all_imgs[j].view(1, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(black_white(all_imgs[i], split_pixel=split_pixel, p=bw).view(1, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(black_white(all_imgs[j], split_pixel=split_pixel, p=bw).view(1, 32, 32))
                    else:
                        rand = np.random.uniform()
                        if rand < 0.5:
                            bw_false = random.choice([item for item in black_white_list if item != bw])
                            img_list.append(black_white(all_imgs[j], split_pixel=split_pixel, p=bw_false).view(1, 32, 32))
                        else:
                            split_pixel_false = random.choice([item for item in pixel_list if item != split_pixel])
                            img_list.append(black_white(all_imgs[j], split_pixel=split_pixel_false, p=bw).view(1, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(black_white(all_imgs[k], split_pixel=split_pixel, p=bw).view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-swap':
        def swap(img, order):
            """Rearrange the four quadrants of ``img``.

            The image is cut at index 16 along both axes into four quadrants,
            numbered 0 = top-left, 1 = top-right, 2 = bottom-left,
            3 = bottom-right. ``order`` lists which source quadrant goes into
            each destination slot, using that same numbering for the slots.
            """
            upper, lower = img[:16], img[16:]
            quadrants = (upper[:, :16], upper[:, 16:], lower[:, :16], lower[:, 16:])
            # Build the output one row of quadrants at a time, then stack the two rows.
            rows = [
                torch.cat((quadrants[order[left]], quadrants[order[left + 1]]), dim=1)
                for left in (0, 2)
            ]
            return torch.cat(rows, dim=0)

        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            swap_order = random.choice(swap_order_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(swap(all_imgs[i], order=swap_order).view(1, 32, 32))
            img_list.append(all_imgs[j].view(1, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(swap(all_imgs[i], order=swap_order).view(1, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(swap(all_imgs[j], order=swap_order).view(1, 32, 32))
                    else:
                        swap_order_false = random.choice([item for item in swap_order_list if item != swap_order])
                        img_list.append(swap(all_imgs[j], order=swap_order_false).view(1, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(swap(all_imgs[k], order=swap_order).view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-fisheye':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot
        import torch.nn.functional as F

        def get_grid_fisheye(H, W, center, magnitude):
            """Build a sampling grid that pulls every point towards ``center``.

            An ``H`` x ``W`` grid of (x, y) coordinates spanning [-1, 1] on each
            axis is created, and every point is moved along its offset to
            ``center`` by ``offset * |offset| * magnitude``.

            Returns the grid with a leading dimension of size 1, i.e. shaped
            ``(1, H, W, 2)`` with x in channel 0 and y in channel 1.
            """
            col_axis = torch.linspace(-1, 1, W)
            row_axis = torch.linspace(-1, 1, H)
            row_coords, col_coords = torch.meshgrid(row_axis, col_axis)
            grid = torch.stack([col_coords, row_coords], dim=-1)

            offset = center - grid
            distance = torch.sqrt((offset ** 2).sum(dim=-1))
            # Updated in place so the grid keeps its own dtype.
            grid.add_(offset * distance.unsqueeze(-1) * magnitude)
            return grid.unsqueeze(0)

        def fisheye(img, center, magnitude):
            """Warp a single 2-D image with the fisheye sampling grid.

            Args:
                img: 2-D image tensor; batch and channel dimensions are added
                    before sampling.
                center: pair of values converted to a tensor and handed to
                    ``get_grid_fisheye`` as the point the grid is pulled towards.
                magnitude: distortion strength handed to ``get_grid_fisheye``.

            Returns:
                The resampled image reshaped to ``(32, 32)``.
            """
            fisheye_grid = get_grid_fisheye(32, 32, torch.tensor(list(center)), magnitude)
            # align_corners=False is the value grid_sample falls back to when the
            # argument is omitted; passing it explicitly keeps the output unchanged
            # and avoids the UserWarning PyTorch emits on every call otherwise.
            fisheye_output = F.grid_sample(img.unsqueeze(0).unsqueeze(0), fisheye_grid,
                                           align_corners=False)
            return fisheye_output.reshape(32, 32)

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            center = random.choice(fisheye_center_list)
            magnitude = random.choice(fisheye_magnitude_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(fisheye(all_imgs[i], center=center, magnitude=magnitude).view(1, 32, 32))
            img_list.append(all_imgs[j].view(1, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(fisheye(all_imgs[i], center=center, magnitude=magnitude).view(1, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(fisheye(all_imgs[j], center=center, magnitude=magnitude).view(1, 32, 32))
                    else:
                        rand = np.random.uniform()
                        if rand < 0.5:
                            magnitude_false = random.choice([item for item in fisheye_magnitude_list if item != magnitude])
                            img_list.append(fisheye(all_imgs[j], center=center, magnitude=magnitude_false).view(1, 32, 32))
                        else:
                            center_false = random.choice([item for item in fisheye_center_list if item != center])
                            img_list.append(fisheye(all_imgs[j], center=center_false, magnitude=magnitude).view(1, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(fisheye(all_imgs[k], center=center, magnitude=magnitude).view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'omniglot-horizontal-wave':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot
        import torch.nn.functional as F

        def get_grid_horizontalwave(height, width, freq, amplitude):
            """Build a sampling grid whose y coordinate is displaced by a cosine of x.

            Starts from a ``height`` x ``width`` grid of (x, y) coordinates
            spanning [-1, 1] on each axis and adds ``amplitude * cos(freq * x)``
            to every y coordinate.

            Returns the grid shaped ``(1, height, width, 2)``; the leading
            dimension is added because the sampling grid has to be 4-D.
            """
            col_axis = torch.linspace(-1, 1, width)
            row_axis = torch.linspace(-1, 1, height)
            row_coords, col_coords = torch.meshgrid(row_axis, col_axis)  # undistorted grid
            grid = torch.stack([col_coords, row_coords], dim=-1)
            # Shift the y channel in place by the wave evaluated at each x.
            grid[..., 1].add_(amplitude * torch.cos(freq * grid[..., 0]))
            return grid.unsqueeze(0)

        def horizontal_wave(img, freq, amplitude):
            """Warp a single 2-D image with the horizontal-wave sampling grid.

            Args:
                img: 2-D image tensor; batch and channel dimensions are added
                    before sampling.
                freq: wave frequency handed to ``get_grid_horizontalwave``.
                amplitude: wave amplitude handed to ``get_grid_horizontalwave``.

            Returns:
                The resampled image reshaped to ``(32, 32)``.
            """
            hw_grid = get_grid_horizontalwave(32, 32, freq, amplitude)
            # align_corners=False is the value grid_sample falls back to when the
            # argument is omitted; passing it explicitly keeps the output unchanged
            # and avoids the UserWarning PyTorch emits on every call otherwise.
            hw_output = F.grid_sample(img.unsqueeze(0).unsqueeze(0), hw_grid,
                                      align_corners=False)
            return hw_output.reshape(32, 32)

        with open('tasks/omniglot_train_class_id_dict.json') as f:
            train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    transforms.Resize((32, 32))])

        trainset = Omniglot(data_dir, background=True, download=True, transform=trans)
        train_loader_mnist = DataLoader(trainset, batch_size=19280, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader_mnist))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(train_class_id_dict[str(class_i)])
            j = random.choice(train_class_id_dict[str(class_j)])
            k = random.choice(train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            freq = random.choice(hw_freq_list)
            amplitude = random.choice(hw_amplitude_list)

            img_list = []
            img_list.append(all_imgs[i].view(1, 32, 32))
            img_list.append(horizontal_wave(all_imgs[i], freq=freq, amplitude=amplitude).view(1, 32, 32))
            img_list.append(all_imgs[j].view(1, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(horizontal_wave(all_imgs[i], freq=freq, amplitude=amplitude).view(1, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(horizontal_wave(all_imgs[j], freq=freq, amplitude=amplitude).view(1, 32, 32))
                    else:
                        rand = np.random.uniform()
                        if rand < 0.5:
                            amplitude_false = random.choice([item for item in hw_amplitude_list if item != amplitude])
                            img_list.append(horizontal_wave(all_imgs[j], freq=freq, amplitude=amplitude_false).view(1, 32, 32))
                        else:
                            freq_false = random.choice([item for item in hw_freq_list if item != freq])
                            img_list.append(horizontal_wave(all_imgs[j], freq=freq_false, amplitude=amplitude).view(1, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(horizontal_wave(all_imgs[k], freq=freq, amplitude=amplitude).view(1, 32, 32))
            img_list = torch.cat(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-translation':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            translate = random.choice(translate_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                         angle=0, translate=translate, scale=1, shear=0, fill=0))
            img_list.append(all_imgs[j].view(3, 32, 32))
            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                                 angle=0, translate=translate, scale=1, shear=0,
                                                                 fill=0))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                     angle=0, translate=translate, scale=1,
                                                                     shear=0, fill=0))
                    else:
                        translate_false = random.choice([item for item in translate_list if item != translate])
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                     angle=0, translate=translate_false, scale=1,
                                                                     shear=0, fill=0))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-rotation':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            angle = random.choice(angle_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                         angle=angle, translate=(0, 0), scale=1, shear=0, fill=1))
            img_list.append(all_imgs[j].view(3, 32, 32))
            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                                 angle=angle, translate=(0, 0), scale=1, shear=0,
                                                                 fill=1))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                     angle=angle, translate=(0, 0), scale=1,
                                                                     shear=0, fill=1))
                    else:
                        angle_false = random.choice([item for item in angle_list if item != angle])
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                     angle=angle_false, translate=(0, 0), scale=1,
                                                                     shear=0, fill=1))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-reflection':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            reflection = random.choice(reflection_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(reflect(all_imgs[i], reflection=reflection).view(3, 32, 32))
            img_list.append(all_imgs[j].view(3, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(reflect(all_imgs[i], reflection=reflection).view(3, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(reflect(all_imgs[j], reflection=reflection).view(3, 32, 32))
                    else:
                        reflection_false = random.choice([item for item in reflection_list if item != reflection])
                        img_list.append(reflect(all_imgs[j], reflection=reflection_false).view(3, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-shear':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            shear = random.choice(shear_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                         angle=0, translate=(0, 0), scale=1, shear=shear, fill=1))
            img_list.append(all_imgs[j].view(3, 32, 32))
            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                                 angle=0, translate=(0, 0), scale=1, shear=shear,
                                                                 fill=1))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=1,
                                                                     shear=shear, fill=1))
                    else:
                        shear_false = random.choice([item for item in shear_list if item != shear])
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=1,
                                                                     shear=shear_false, fill=1))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-scale':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            scale = random.choice(scale_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                         angle=0, translate=(0, 0), scale=scale, shear=0, fill=1))
            img_list.append(all_imgs[j].view(3, 32, 32))
            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                                 angle=0, translate=(0, 0), scale=scale, shear=0,
                                                                 fill=1))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=scale,
                                                                     shear=0, fill=1))
                    else:
                        # Distractor choice: image j rescaled by a factor different from
                        # the one used for the correct answer. The sampled factor has to go
                        # to ``scale`` (not ``angle``), otherwise the distractor is a slightly
                        # rotated, unscaled image.
                        scale_false = random.choice([item for item in scale_list if item != scale])
                        img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=scale_false,
                                                                     shear=0, fill=1))
                elif multiple_choice_list[t] == k:
                    img_list.append(all_imgs[k].view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-all':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            transformation_list = [item.strip() for item in args.transformation_mixture.split(',')]
            # transformation_list = ['translation', 'rotation', 'shear', 'scale', 'reflection']
            transformation = random.choice(transformation_list)

            img_list = []
            if transformation == 'translation':
                if args.generalization_type == 'object':
                    translate = random.choice(translate_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        translate = random.choice(translate_train_list)
                    else:
                        translate = random.choice(translate_test_list)

                img_list.append(all_imgs[i].view(3, 32, 32))
                img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                             angle=0, translate=translate, scale=1, shear=0, fill=1))
                img_list.append(all_imgs[j].view(3, 32, 32))
                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                                     angle=0, translate=translate, scale=1, shear=0,
                                                                     fill=1))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                         angle=0, translate=translate, scale=1,
                                                                         shear=0, fill=1))
                        else:
                            translate_false = random.choice([item for item in translate_list if item != translate])
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                         angle=0, translate=translate_false, scale=1,
                                                                         shear=0, fill=1))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(3, 32, 32))
            elif transformation == 'rotation':
                if args.generalization_type == 'object':
                    angle = random.choice(angle_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        angle = random.choice(angle_train_list)
                    else:
                        angle = random.choice(angle_test_list)

                img_list.append(all_imgs[i].view(3, 32, 32))
                img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                             angle=angle, translate=(0, 0), scale=1, shear=0, fill=1))
                img_list.append(all_imgs[j].view(3, 32, 32))
                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                                     angle=angle, translate=(0, 0), scale=1, shear=0,
                                                                     fill=1))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                         angle=angle, translate=(0, 0), scale=1,
                                                                         shear=0, fill=1))
                        else:
                            angle_false = random.choice([item for item in angle_list if item != angle])
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                         angle=angle_false, translate=(0, 0), scale=1,
                                                                         shear=0, fill=1))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(3, 32, 32))
            elif transformation == 'shear':
                if args.generalization_type == 'object':
                    shear = random.choice(shear_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        shear = random.choice(shear_train_list)
                    else:
                        shear = random.choice(shear_test_list)

                img_list.append(all_imgs[i].view(3, 32, 32))
                img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                             angle=0, translate=(0, 0), scale=1, shear=shear, fill=1))
                img_list.append(all_imgs[j].view(3, 32, 32))
                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=1, shear=shear,
                                                                     fill=1))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                         angle=0, translate=(0, 0), scale=1,
                                                                         shear=shear, fill=1))
                        else:
                            shear_false = random.choice([item for item in shear_list if item != shear])
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                         angle=0, translate=(0, 0), scale=1,
                                                                         shear=shear_false, fill=1))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(3, 32, 32))
            elif transformation == 'scale':
                if args.generalization_type == 'object':
                    scale = random.choice(scale_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        scale = random.choice(scale_train_list)
                    else:
                        scale = random.choice(scale_test_list)

                img_list.append(all_imgs[i].view(3, 32, 32))
                img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                             angle=0, translate=(0, 0), scale=scale, shear=0, fill=1))
                img_list.append(all_imgs[j].view(3, 32, 32))
                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(transforms.functional.affine(img=all_imgs[i].view(3, 32, 32),
                                                                     angle=0, translate=(0, 0), scale=scale, shear=0,
                                                                     fill=1))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                         angle=0, translate=(0, 0), scale=scale,
                                                                         shear=0, fill=1))
                        else:
                            scale_false = random.choice([item for item in scale_list if item != scale])
                            img_list.append(transforms.functional.affine(img=all_imgs[j].view(3, 32, 32),
                                                                         angle=scale_false, translate=(0, 0), scale=1,
                                                                         shear=0, fill=1))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(3, 32, 32))
            elif transformation == 'reflection':
                if args.generalization_type == 'object':
                    reflection = random.choice(reflection_list)
                elif args.generalization_type == 'object+function':
                    if train:
                        reflection = random.choice(reflection_train_list)
                    else:
                        reflection = random.choice(reflection_test_list)

                img_list = []
                img_list.append(all_imgs[i].view(3, 32, 32))
                img_list.append(reflect(all_imgs[i], reflection=reflection).view(3, 32, 32))
                img_list.append(all_imgs[j].view(3, 32, 32))

                for t in range(4):
                    if multiple_choice_list[t] == i:
                        img_list.append(reflect(all_imgs[i], reflection=reflection).view(3, 32, 32))
                    elif multiple_choice_list[t] == j:
                        if t == target:
                            img_list.append(reflect(all_imgs[j], reflection=reflection).view(3, 32, 32))
                        else:
                            reflection_false = random.choice([item for item in reflection_list if item != reflection])
                            img_list.append(reflect(all_imgs[j], reflection=reflection_false).view(3, 32, 32))
                    elif multiple_choice_list[t] == k:
                        img_list.append(all_imgs[k].view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-horizontal-wave':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot
        import torch.nn.functional as F

        def get_grid_horizontalwave(height, width, freq, amplitude):
            """Build a sampling grid with a cosine displacement along y.

            The base grid spans [-1, 1] on both axes; the last dimension holds
            (x, y). Every y-coordinate is shifted by
            ``amplitude * cos(freq * x)``.

            Returns a tensor of shape (1, height, width, 2).
            """
            col_coords = torch.linspace(-1, 1, width)
            row_coords = torch.linspace(-1, 1, height)
            row_grid, col_grid = torch.meshgrid(row_coords, col_coords)  # identity grid
            # Vertical offset depends only on the horizontal coordinate.
            wave = amplitude * torch.cos(freq * col_grid)
            grid = torch.stack([col_grid, row_grid], dim=-1)
            # Shift the y-channel in place through a view of the stacked grid.
            grid.select(-1, 1).add_(wave)
            # Add the leading batch axis: the grid has to be 4D.
            return grid.unsqueeze(0)

        def horizontal_wave(img, freq, amplitude):
            """Resample ``img`` on a 32x32 horizontal-wave grid.

            ``img`` gets a leading batch axis before sampling, and the sampled
            result is returned reshaped to (3, 32, 32). Sampling uses the
            default settings of ``F.grid_sample``.
            """
            wave_grid = get_grid_horizontalwave(32, 32, freq, amplitude)
            batched_img = img[None]  # grid_sample expects a batch dimension
            warped = F.grid_sample(batched_img, wave_grid)
            return warped.reshape(3, 32, 32)

        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            freq = random.choice(hw_freq_list)
            amplitude = random.choice(hw_amplitude_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(horizontal_wave(all_imgs[i], freq=freq, amplitude=amplitude).view(3, 32, 32))
            img_list.append(all_imgs[j].view(3, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(horizontal_wave(all_imgs[i], freq=freq, amplitude=amplitude).view(3, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(horizontal_wave(all_imgs[j], freq=freq, amplitude=amplitude).view(3, 32, 32))
                    else:
                        rand = np.random.uniform()
                        if rand < 0.5:
                            amplitude_false = random.choice([item for item in hw_amplitude_list if item != amplitude])
                            img_list.append(horizontal_wave(all_imgs[j], freq=freq, amplitude=amplitude_false).view(3, 32, 32))
                        else:
                            freq_false = random.choice([item for item in hw_freq_list if item != freq])
                            img_list.append(horizontal_wave(all_imgs[j], freq=freq_false, amplitude=amplitude).view(3, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(horizontal_wave(all_imgs[k], freq=freq, amplitude=amplitude).view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-fisheye':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import Omniglot
        import torch.nn.functional as F

        def get_grid_fisheye(H, W, center, magnitude):
            """Build a sampling grid distorted around ``center``.

            The base grid spans [-1, 1] on both axes; the last dimension holds
            (x, y). Each point is moved along its offset to ``center`` by an
            amount scaled by its distance to ``center`` and by ``magnitude``.

            Returns a tensor of shape (1, H, W, 2).
            """
            col_coords = torch.linspace(-1, 1, W)
            row_coords = torch.linspace(-1, 1, H)
            row_grid, col_grid = torch.meshgrid(row_coords, col_coords)
            grid = torch.stack((col_grid, row_grid), dim=-1)
            # Offset and distance are taken from the undistorted grid.
            offset = center - grid
            distance = offset.pow(2).sum(dim=-1, keepdim=True).sqrt()
            # In-place update keeps the grid's own dtype.
            grid.add_(offset * distance * magnitude)
            return grid[None]

        def fisheye(img, center, magnitude):
            """Resample ``img`` on a 32x32 fisheye grid.

            ``center`` is converted to a tensor before the grid is built.
            ``img`` gets a leading batch axis before sampling, and the sampled
            result is returned reshaped to (3, 32, 32). Sampling uses the
            default settings of ``F.grid_sample``.
            """
            center_tensor = torch.tensor(list(center))
            warp_grid = get_grid_fisheye(32, 32, center_tensor, magnitude)
            batched_img = img[None]  # grid_sample expects a batch dimension
            warped = F.grid_sample(batched_img, warp_grid)
            return warped.reshape(3, 32, 32)

        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            center = random.choice(fisheye_center_list)
            magnitude = random.choice(fisheye_magnitude_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(fisheye(all_imgs[i], center=center, magnitude=magnitude).view(3, 32, 32))
            img_list.append(all_imgs[j].view(3, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(fisheye(all_imgs[i], center=center, magnitude=magnitude).view(3, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(fisheye(all_imgs[j], center=center, magnitude=magnitude).view(3, 32, 32))
                    else:
                        rand = np.random.uniform()
                        if rand < 0.5:
                            magnitude_false = random.choice([item for item in fisheye_magnitude_list if item != magnitude])
                            img_list.append(fisheye(all_imgs[j], center=center, magnitude=magnitude_false).view(3, 32, 32))
                        else:
                            center_false = random.choice([item for item in fisheye_center_list if item != center])
                            img_list.append(fisheye(all_imgs[j], center=center_false, magnitude=magnitude).view(3, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(fisheye(all_imgs[k], center=center, magnitude=magnitude).view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-black-white':
        def gray_scale(img):
            """Replace every slice along dim 0 with the mean over dim 0.

            The mean over the first dimension is computed once and stacked
            three times, so the result always has size 3 along dim 0.
            """
            channel_avg = img.mean(dim=0)
            return torch.stack((channel_avg, channel_avg, channel_avg), dim=0)

        def black_white(img, split_pixel, p):
            """Invert (``1 - x``) one side of ``img`` split at ``split_pixel``.

            ``p`` selects the split axis and which side is inverted:
              (0, 0): split along dim 1, invert the first part
              (0, 1): split along dim 1, invert the second part
              (1, 0): split along dim 2, invert the first part
              anything else: split along dim 2, invert the second part

            The two parts are joined with numpy and returned as a new
            ``torch.Tensor``.
            """
            split_on_dim1 = p == (0, 0) or p == (0, 1)
            invert_first = p == (0, 0) or p == (1, 0)

            if split_on_dim1:
                axis = 1
                first, second = img[:, :split_pixel, :], img[:, split_pixel:, :]
            else:
                axis = 2
                first, second = img[:, :, :split_pixel], img[:, :, split_pixel:]

            if invert_first:
                pieces = (1 - first, second)
            else:
                pieces = (first, 1 - second)

            return torch.Tensor(np.concatenate(pieces, axis=axis))

        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            split_pixel = random.choice(pixel_list)
            bw = random.choice(black_white_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(black_white(all_imgs[i], split_pixel=split_pixel, p=bw).view(3, 32, 32))
            img_list.append(all_imgs[j].view(3, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(black_white(all_imgs[i], split_pixel=split_pixel, p=bw).view(3, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(black_white(all_imgs[j], split_pixel=split_pixel, p=bw).view(3, 32, 32))
                    else:
                        rand = np.random.uniform()
                        if rand < 0.5:
                            bw_false = random.choice([item for item in black_white_list if item != bw])
                            img_list.append(black_white(all_imgs[j], split_pixel=split_pixel, p=bw_false).view(3, 32, 32))
                        else:
                            split_pixel_false = random.choice([item for item in pixel_list if item != split_pixel])
                            img_list.append(black_white(all_imgs[j], split_pixel=split_pixel_false, p=bw).view(3, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(black_white(all_imgs[k], split_pixel=split_pixel, p=bw).view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)
    elif args.transformation_method == 'cifar100-swap':
        import json
        from torch.utils.data import DataLoader
        from torchvision.datasets import CIFAR100

        def swap(img, order):
            """Rearrange the four 16-pixel quadrants of ``img``.

            Quadrants are numbered 0: top-left, 1: top-right, 2: bottom-left,
            3: bottom-right (splitting dims 1 and 2 at index 16). ``order``
            lists which source quadrant goes to the top-left, top-right,
            bottom-left and bottom-right position of the result.
            """
            head, tail = slice(None, 16), slice(16, None)
            quadrants = [img[:, rows, cols] for rows in (head, tail) for cols in (head, tail)]
            # Join each pair side by side, then place the two rows on top of each other.
            rows_out = []
            for left in (0, 2):
                pair = (quadrants[order[left]], quadrants[order[left + 1]])
                rows_out.append(torch.cat(pair, dim=-1))
            return torch.cat(rows_out, dim=-2)

        with open('tasks/cifar100_train_class_id_dict.json') as f:
            cifar100_train_class_id_dict = json.load(f)

        data_dir = 'data'
        trans = transforms.Compose([transforms.ToTensor(),
                                    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                    ]
                                   )

        trainset = CIFAR100(data_dir, train=True, download=True, transform=trans)
        train_loader = DataLoader(trainset, batch_size=60000, shuffle=False)  # NO SHUFFLE HERE!
        all_imgs = next(iter(train_loader))[0].squeeze(1)

        for _ in range(seq_len):
            class_i = random.choice(shapes.tolist())
            class_j = random.choice([item for item in shapes.tolist() if item != class_i])
            class_k = random.choice([item for item in shapes.tolist() if item != class_i and item != class_j])

            i = random.choice(cifar100_train_class_id_dict[str(class_i)])
            j = random.choice(cifar100_train_class_id_dict[str(class_j)])
            k = random.choice(cifar100_train_class_id_dict[str(class_k)])

            multiple_choice_list = [i, j, j, k]
            random.shuffle(multiple_choice_list)
            target = random.choice([index for index in [t for t in range(4) if multiple_choice_list[t] == j]])

            swap_order = random.choice(swap_order_list)

            img_list = []
            img_list.append(all_imgs[i].view(3, 32, 32))
            img_list.append(swap(all_imgs[i], order=swap_order).view(3, 32, 32))
            img_list.append(all_imgs[j].view(3, 32, 32))

            for t in range(4):
                if multiple_choice_list[t] == i:
                    img_list.append(swap(all_imgs[i], order=swap_order).view(3, 32, 32))
                elif multiple_choice_list[t] == j:
                    if t == target:
                        img_list.append(swap(all_imgs[j], order=swap_order).view(3, 32, 32))
                    else:
                        swap_order_false = random.choice([item for item in swap_order_list if item != swap_order])
                        img_list.append(swap(all_imgs[j], order=swap_order_false).view(3, 32, 32))
                elif multiple_choice_list[t] == k:
                    img_list.append(swap(all_imgs[k], order=swap_order).view(3, 32, 32))
            img_list = torch.stack(img_list, 0)

            all_seq.append(img_list)
            all_target.append(target)

    all_seq = torch.stack(all_seq, dim=0)
    return all_seq, all_target


# Task generator
def create_task(args, train_shapes, test_shapes):
    """Generate the training and test problem sets.

    ``build_seq`` is called once with ``train_shapes`` and ``args.num_train``
    (``train=True``) and once with ``test_shapes`` and ``args.num_test``
    (``train=False``).

    Returns:
        ``(args, train_set, test_set)``, where each set is a dict holding the
        sequences under ``'seq'`` and the targets under ``'y'``.
    """
    def make_split(shapes, num_problems, is_train):
        # Wrap one build_seq call into the dict layout of a split.
        sequences, targets = build_seq(args=args, shapes=shapes, seq_len=num_problems, train=is_train)
        return {'seq': sequences, 'y': targets}

    train_set = make_split(train_shapes, args.num_train, True)
    test_set = make_split(test_shapes, args.num_test, False)

    return args, train_set, test_set



