import os
import pickle
import numpy as np
import random
from PIL import Image

try:
    # Package-relative imports when this module is loaded as part of its package.
    from . import raw
    from . import utils
except ImportError:
    # Fallback for running this file directly as a script (no parent package):
    # make the directory that contains this file importable, then import the
    # sibling modules by absolute name. Only ImportError is caught so that
    # genuine errors raised while executing raw/utils are not masked.
    import sys
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    import raw
    import utils

# Registry mapping a dataset name (the ``--raw_data`` value) to the loader in
# ``raw`` that reads it; build_and_transform calls it as ``loader(data_path=...)``.
RAW_DATA_CALL = dict(
    mnist=raw.load_mnist,
    cifar100=raw.load_cifar100,
)

def build_and_transform(args):
    """Build pixel-permuted tasks from a raw dataset and hand them to ``utils``.

    For each of ``args.num_tasks`` tasks a random pixel permutation ``p`` is
    drawn and applied to every train and test image of ``args.raw_data``.
    Every entry of the resulting train/test task lists is ``[p, images, labels]``;
    both lists are passed to ``utils.save_transformed_data``.

    Args:
        args: namespace providing ``seed``, ``root_path`` (``None`` selects
            ``utils.root_path``), ``i`` (sub-directory under the root),
            ``raw_data`` (a key of ``RAW_DATA_CALL``) and ``num_tasks``. It is
            also forwarded unchanged to ``utils.save_transformed_data``.

    Raises:
        KeyError: if ``args.raw_data`` is not a key of ``RAW_DATA_CALL``.
    """
    # Seed both RNGs: the permutations are drawn from NumPy, so seeding only
    # the stdlib ``random`` module left ``--seed`` without any effect.
    random.seed(args.seed)
    np.random.seed(args.seed)

    root_path = utils.root_path if args.root_path is None else args.root_path
    data_path = os.path.join(root_path, args.i, args.raw_data)
    train_data, test_data = RAW_DATA_CALL[args.raw_data](data_path=data_path)

    # NOTE(review): assumes each split is an (images, labels) pair with images
    # stacked along axis 0 — confirm against the raw loaders.
    train_images = np.asarray(train_data[0])
    test_images = np.asarray(test_data[0])

    # Per-image shape with singleton axes dropped (what np.squeeze does per
    # image), computed from the array shape so an empty split still works.
    img_shape = tuple(d for d in train_images.shape[1:] if d != 1)
    num_pixels = int(np.prod(img_shape))
    # Derive the permutation size from the data instead of hard-coding 28*28:
    # larger images (e.g. cifar100) were previously truncated to 784 values.
    # MNIST-sized images keep the historical (28, 28) output layout.
    out_shape = (28, 28) if num_pixels == 28 * 28 else img_shape

    def permute(images, p):
        """Return a new array with permutation ``p`` applied to every image's pixels."""
        flat = images.reshape(len(images), num_pixels)
        return flat[:, p].reshape((len(images),) + tuple(out_shape))

    train_tasks = []
    test_tasks = []
    for _ in range(args.num_tasks):
        # One permutation per task, shared by that task's train and test split.
        p = np.random.permutation(num_pixels)
        train_tasks.append([p, permute(train_images, p), train_data[1]])
        test_tasks.append([p, permute(test_images, p), test_data[1]])

    utils.save_transformed_data(root_path, train_tasks, test_tasks, args)


if __name__ == "__main__":
    # Command-line entry point: parse options and build the permuted tasks.
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('--root_path', default=None, help='data root path')
    parser.add_argument('--i', default='raw', help='raw datasets for continual learning')
    parser.add_argument('--num_tasks', default=3, type=int, help='number of different tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    # Restrict to datasets that have a registered loader so a typo is reported
    # by argparse instead of surfacing later as a bare KeyError.
    parser.add_argument('--raw_data', default='mnist', choices=sorted(RAW_DATA_CALL),
                        help='raw dataset for using continual learning')
    # build_and_transform only implements pixel permutation; reject other
    # values rather than silently ignoring them.
    parser.add_argument('--transform', default='permutation', choices=['permutation'],
                        help='transform transformation')

    args = parser.parse_args()
    build_and_transform(args)