import os
import pickle
import numpy as np
import random
from PIL import Image

try:
    from . import raw
    from . import utils
except (ImportError, ValueError):
    # Not imported as part of a package (e.g. executed directly as a script).
    # Python 3 raises ImportError for the relative import in that case,
    # Python 2 raises ValueError ("Attempted relative import in non-package").
    # Fall back to absolute imports of the sibling modules.
    import sys
    # sys.path entries must be directories, so add the folder containing this file.
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    import raw
    import utils

# Maps the --raw_data name to its loader in ``raw``. Each loader is called as
# loader(data_path=...) and its result is unpacked as (train_data, test_data).
RAW_DATA_CALL = {'mnist': raw.load_mnist,
                 'cifar100': raw.load_cifar100}

def _task_rotation(task, num_tasks, min_angle, max_angle):
    """Draw one rotation angle for ``task`` from its slice of [min_angle, max_angle).

    The full angle range is split into ``num_tasks`` equal, consecutive
    sub-intervals; task ``t`` samples uniformly (via the module-level
    ``random`` generator) from the ``t``-th one.
    """
    angle_range = max_angle - min_angle
    # Same arithmetic as before so previously generated angles are reproduced exactly.
    low = min_angle + 1.0 * task / num_tasks * angle_range
    high = min_angle + 1.0 * (task + 1) / num_tasks * angle_range
    return random.random() * (high - low) + low


def _rotate_images(images, angle):
    """Rotate every image in ``images`` by ``angle`` degrees and stack the results."""
    def rotate(img):
        # PIL's 'F' mode is 32-bit float. Cast explicitly so that float64/uint8
        # inputs are converted instead of having their raw bytes reinterpreted.
        # A 2-D float32 array is mapped to mode 'F' by fromarray automatically.
        arr = np.asarray(np.squeeze(img), dtype=np.float32)
        return np.array(Image.fromarray(arr).rotate(angle))

    return np.stack([rotate(img) for img in images])


def build_and_transform(args):
    """Build one rotated copy of a raw dataset per task and save them.

    ``args`` attributes read here: ``seed``, ``root_path`` (``None`` means
    ``utils.root_path``), ``i`` (sub-directory under the root), ``raw_data``
    (key of ``RAW_DATA_CALL``), ``num_tasks``, ``min_angle`` and ``max_angle``.
    ``args`` is also forwarded to ``utils.save_transformed_data``.

    Each task gets a single angle, drawn from its own sub-interval of
    [min_angle, max_angle), and that angle is applied to every train and test
    image. Task entries are ``[rotation, images, data[1]]``, where ``data[1]``
    (presumably labels) is passed through unchanged.

    Raises:
        KeyError: if ``args.raw_data`` is not a key of ``RAW_DATA_CALL``.
    """
    random.seed(args.seed)
    root_path = utils.root_path if args.root_path is None else args.root_path
    data_path = os.path.join(root_path, args.i, args.raw_data)

    try:
        loader = RAW_DATA_CALL[args.raw_data]
    except KeyError:
        raise KeyError('unknown raw_data %r; expected one of %s'
                       % (args.raw_data, sorted(RAW_DATA_CALL)))
    train_data, test_data = loader(data_path=data_path)

    train_tasks = []
    test_tasks = []
    for task in range(args.num_tasks):
        # Exactly one random draw per task, in task order, so a given seed
        # yields the same angles as before.
        rotation = _task_rotation(task, args.num_tasks, args.min_angle, args.max_angle)
        train_tasks.append([rotation, _rotate_images(train_data[0], rotation), train_data[1]])
        test_tasks.append([rotation, _rotate_images(test_data[0], rotation), test_data[1]])

    utils.save_transformed_data(root_path, train_tasks, test_tasks, args)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description='Build rotated copies of a raw dataset, one rotation angle per task.')

    parser.add_argument('--root_path', default=None,
                        help='data root path (defaults to utils.root_path)')
    parser.add_argument('--i', default='raw',
                        help='sub-directory of root_path that holds the raw datasets')
    parser.add_argument('--num_tasks', default=3, type=int, help='number of different tasks')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    # Restrict to the loaders that exist so a typo fails with a usage message
    # instead of a KeyError traceback after argument parsing.
    parser.add_argument('--raw_data', default='mnist', choices=sorted(RAW_DATA_CALL),
                        help='raw dataset to transform for continual learning')
    parser.add_argument('--transform', default='rotation',
                        help='name of the transformation (only rotation is implemented here)')
    parser.add_argument('--min_angle', default=0, type=float, help='minimum rotation angle (degrees)')
    parser.add_argument('--max_angle', default=90, type=float, help='maximum rotation angle (degrees)')

    args = parser.parse_args()
    build_and_transform(args)