import os
import pickle
import numpy as np
import random
from PIL import Image

try:
    from . import raw
    from . import utils
except ImportError:
    # Not imported as part of a package (e.g. the file is run directly as a
    # script, where relative imports raise ImportError): fall back to plain
    # imports of the sibling modules. The *directory* containing this file
    # must be on sys.path for that to work, not the file path itself.
    import sys
    sys.path.append(os.path.dirname(os.path.abspath(__file__)))
    import raw
    import utils

# Maps a dataset name (the ``--raw_data`` CLI value) to its loader. Each
# loader is called as ``loader(data_path=...)`` and must return a
# ``(train_data, test_data)`` pair.
RAW_DATA_CALL = {'mnist': raw.load_mnist,
                 'cifar100': raw.load_cifar100,
                 'cifar10': raw.load_cifar10}

def _to_int_labels(label_data):
    """Return *label_data* as integer class ids.

    Arrays with more than one dimension (e.g. one-hot rows) are collapsed
    with ``argmax`` over the last axis; 1-D arrays are returned unchanged.
    """
    if label_data.ndim > 1:
        return np.argmax(label_data, axis=-1)
    return label_data


def _num_classes(label_data, *int_labels):
    """Return the number of classes encoded by *label_data*.

    For multi-dimensional (one-hot style) labels this is the width of the
    last axis. For 1-D integer labels it is one more than the largest id in
    *int_labels* (empty arrays are ignored).
    NOTE(review): the 1-D case assumes zero-based, contiguous class ids.
    """
    if label_data.ndim > 1:
        return label_data.shape[-1]
    largest = max((int(np.max(labels)) for labels in int_labels if np.size(labels)),
                  default=-1)
    return largest + 1


def _squeeze_non_batch(x):
    """Drop singleton axes of *x*, but never the leading (sample) axis.

    Plain ``np.squeeze`` would also remove axis 0 when a task contains a
    single sample, changing the array's rank.
    """
    axes = tuple(i for i in range(1, x.ndim) if x.shape[i] == 1)
    return np.squeeze(x, axis=axes) if axes else x


def build_and_transform(args):
    """Split a raw dataset into class-incremental tasks and save them.

    The dataset ``args.raw_data`` is loaded through ``RAW_DATA_CALL`` from
    ``<root>/<args.i>/<args.raw_data>``. Its classes are partitioned into
    ``args.num_tasks`` contiguous, equally sized half-open ranges
    ``[c_start, c_end)``. The per-task splits are then passed to
    ``utils.save_transformed_data``.

    Each task entry is ``[(c_start, c_end), inputs, labels]``. The labels
    keep whatever format the loader returned (one-hot rows or integer ids).

    Args:
        args: namespace with ``seed``, ``root_path`` (``None`` means use
            ``utils.root_path``), ``i``, ``raw_data`` and ``num_tasks``. It
            is also forwarded as-is to ``utils.save_transformed_data``.

    Raises:
        ValueError: if ``args.num_tasks`` is not positive or does not evenly
            divide the number of classes.
        KeyError: if ``args.raw_data`` is not a key of ``RAW_DATA_CALL``.
    """
    # Validate before the (potentially slow) dataset load.
    if args.num_tasks < 1:
        raise ValueError(f"number of tasks must be positive, got {args.num_tasks}")

    random.seed(args.seed)
    root_path = utils.root_path if args.root_path is None else args.root_path
    data_path = os.path.join(root_path, args.i, args.raw_data)
    train_data, test_data = RAW_DATA_CALL[args.raw_data](data_path=data_path)

    train_labels_int = _to_int_labels(train_data[1])
    test_labels_int = _to_int_labels(test_data[1])

    # Count classes from the label format. shape[-1] is only the class
    # count for one-hot labels; for 1-D ids it would be the sample count.
    num_classes = _num_classes(test_data[1], train_labels_int, test_labels_int)
    if num_classes % args.num_tasks != 0:
        raise ValueError(f"cannot divide the number of classes {num_classes} by the number of tasks {args.num_tasks}")
    num_classes_per_task = num_classes // args.num_tasks
    print(f"{num_classes} classes -> {num_classes_per_task} classes per task")

    train_tasks = []
    test_tasks = []
    for task in range(args.num_tasks):
        c_start = task * num_classes_per_task
        c_end = (task + 1) * num_classes_per_task
        classes = (c_start, c_end)

        # Boolean masks selecting the samples whose class is in [c_start, c_end).
        train_mask = np.logical_and(train_labels_int >= c_start, train_labels_int < c_end)
        test_mask = np.logical_and(test_labels_int >= c_start, test_labels_int < c_end)

        train_data_T = _squeeze_non_batch(train_data[0][train_mask])
        train_label_T = train_data[1][train_mask]
        test_data_T = _squeeze_non_batch(test_data[0][test_mask])
        test_label_T = test_data[1][test_mask]
        print(f"task {task}: classes {classes}, "
              f"{len(train_label_T)} train / {len(test_label_T)} test samples")

        train_tasks.append([classes, train_data_T, train_label_T])
        test_tasks.append([classes, test_data_T, test_label_T])

    utils.save_transformed_data(root_path, train_tasks, test_tasks, args)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(
        description='Split a raw dataset into class-incremental tasks.')

    parser.add_argument('--root_path', default=None,
                        help='data root path (defaults to utils.root_path)')
    parser.add_argument('--i', default='raw',
                        help='sub-directory of the root path that holds the raw datasets')
    parser.add_argument('--num_tasks', default=5, type=int,
                        help='number of tasks; must evenly divide the number of classes')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    # Restrict to datasets that have a registered loader so a typo fails
    # here with a clear message instead of a KeyError after parsing.
    parser.add_argument('--raw_data', default='mnist', choices=sorted(RAW_DATA_CALL),
                        help='raw dataset to split into tasks')
    parser.add_argument('--transform', default='split',
                        help='name of the transformation; not read in this file, '
                             'forwarded to utils.save_transformed_data via args')

    args = parser.parse_args()
    build_and_transform(args)