import tqdm
import math
import torch
import numpy as np
from PIL import Image
from continualai.colab.scripts import mnist
from torch.utils.data import Dataset, DataLoader

def prepare_mnist():
    """Initialise the ``mnist`` helper and return its train/test splits.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` -- the four arrays returned
        by ``mnist.load()``, regrouped into (images, labels) pairs.
    """
    # init() is called before every load(); presumably it fetches/prepares the
    # data files -- TODO confirm against continualai's mnist script.
    mnist.init()
    loaded = mnist.load()
    x_train, y_train, x_test, y_test = loaded
    train_split = (x_train, y_train)
    test_split = (x_test, y_test)
    return train_split, test_split


def rotate_dataset(d, seed):
    """Rotate every image in ``d`` by a single angle determined by ``seed``.

    The angle is ``180 * (U - 0.5)`` degrees with ``U`` drawn uniformly from
    ``[0, 1)``, i.e. in ``[-90, 90)``. It is drawn once per call, so the same
    ``seed`` always yields the same rotation for every image.

    Args:
        d: Image array indexed as ``d[i][0]`` -> one 2-D image; presumably
            shaped ``(N, 1, H, W)`` -- TODO confirm against ``mnist.load()``.
        seed: Integer seed that fixes the rotation angle.

    Returns:
        ``np.float32`` array with the same shape as ``d`` holding the rotated
        images.
    """
    # A private generator yields the same value as seeding torch's global RNG
    # would, but without clobbering the caller's global RNG state (which would
    # otherwise override any experiment seed used for model init/shuffling).
    generator = torch.Generator().manual_seed(int(seed))
    rand_rotate = torch.rand(1, generator=generator)
    # PIL's Image.rotate takes a plain number of degrees, not a tensor.
    rotation = (180 * (rand_rotate - 0.5)).item()
    # Every slot is overwritten below, so no zero-initialisation is needed.
    rotated_mnist = np.empty(d.shape, np.float32)

    for i in range(d.shape[0]):
        img = Image.fromarray(d[i][0])
        img = img.rotate(rotation)
        rotated_mnist[i, 0] = np.asarray(img)
    return rotated_mnist


class ILDataset(Dataset):
    """Rotated-MNIST dataset for one step of an incremental task sequence.

    Task ``t`` is MNIST rotated by the angle ``rotate_dataset`` derives from
    seed ``t``. The training split holds only task ``task_id``. The test split
    stacks the test images of every task ``1 .. task_id``, so evaluation
    covers all tasks seen so far.

    Args:
        task_id: Index of the current task (the test split counts tasks
            from 1).
        train: ``True`` for the training split, ``False`` for the cumulative
            test split.
    """

    def __init__(self, task_id, train=True):
        self.task_id = task_id
        self.train_flag = train
        train_split, test_split = prepare_mnist()

        if train:
            images, labels = train_split
            self.x = rotate_dataset(images, seed=task_id)
            self.y = labels
            return

        images, labels = test_split
        tasks = range(1, task_id + 1)
        rotated = [rotate_dataset(images, t) for t in tasks]
        label_copies = [labels for _ in tasks]
        # A single task is used directly; several are stacked along axis 0.
        if len(rotated) > 1:
            self.x = np.concatenate(rotated, axis=0)
            self.y = np.concatenate(label_copies)
        else:
            self.x = rotated[0]
            self.y = label_copies[0]

    def __len__(self):
        """Return the number of samples (first axis of ``x``)."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]


class BufferILDataset(Dataset):
    """Rotated-MNIST task dataset with a rehearsal buffer of earlier tasks.

    The training split holds task ``task_id`` (MNIST rotated with seed
    ``task_id``). When ``task_id > 1``, it is followed by the samples returned
    by ``get_buffer`` for the earlier tasks. The test split stacks the test
    images of every task ``1 .. task_id``.

    Args:
        task_id: Index of the current task (the test split counts tasks
            from 1).
        train: ``True`` for the training split, ``False`` for the cumulative
            test split.
        buffer_size: Budget forwarded to ``get_buffer``; it only affects the
            training split.
    """

    def __init__(self, task_id, train=True, buffer_size=200):
        self.task_id = task_id
        self.train_flag = train
        train_split, test_split = prepare_mnist()

        if train:
            images, labels = train_split
            current_x = rotate_dataset(images, seed=task_id)
            if task_id > 1:
                # Rehearsal: append the earlier-task exemplars after the
                # current task's samples.
                memory_x, memory_y = get_buffer(images, labels, buffer_size, task_id)
                current_x = np.concatenate([current_x, memory_x], axis=0)
                labels = np.concatenate([labels, memory_y])
            self.x = current_x
            self.y = labels
            return

        images, labels = test_split
        tasks = range(1, task_id + 1)
        rotated = [rotate_dataset(images, t) for t in tasks]
        label_copies = [labels for _ in tasks]
        # A single task is used directly; several are stacked along axis 0.
        if len(rotated) > 1:
            self.x = np.concatenate(rotated, axis=0)
            self.y = np.concatenate(label_copies)
        else:
            self.x = rotated[0]
            self.y = label_copies[0]

    def __len__(self):
        """Return the number of samples (first axis of ``x``)."""
        return len(self.x)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]


def get_buffer(x, y, buffer_size, task_id):
    """Build a fixed rehearsal buffer from the tasks preceding ``task_id``.

    For every earlier task ``i`` in ``1 .. task_id - 1`` and every label
    ``0 .. 9``, the first ``num`` samples with that label (lowest indices, so
    the selection is deterministic) are taken from ``x``. They are rotated
    with ``rotate_dataset(..., seed=i)``, the same rotation task ``i`` uses.

    Args:
        x: Training images, indexed along axis 0.
        y: Labels aligned with ``x``; labels 0-9 are sampled.
        buffer_size: Budget used to derive the per-(task, label) quota
            ``num = ceil(buffer_size / task_id / 10)``.
        task_id: Current task; tasks ``1 .. task_id - 1`` are replayed.

    Returns:
        ``(buffer_x, buffer_y)``: rotated images (float32) and their labels,
        ordered by task, then label, then index. Both are zero-length when
        there is no earlier task (``task_id <= 1``).
    """
    if task_id <= 1:
        # Nothing to replay. Returning early also avoids dividing by zero and
        # concatenating an empty list below.
        return np.empty((0,) + tuple(x.shape[1:]), np.float32), y[:0]

    # NOTE(review): the quota divides by ``task_id`` although only
    # ``task_id - 1`` tasks are replayed, and ``ceil`` rounds up. The buffer
    # therefore holds (task_id - 1) * 10 * num samples, which generally is not
    # ``buffer_size`` (e.g. 100 for buffer_size=200, task_id=2). It is kept
    # unchanged to preserve existing results -- confirm this is intended.
    num = math.ceil(buffer_size / task_id / 10)

    # The selected indices do not depend on the task, so compute them once.
    # np.where already returns ascending integer indices, even when empty.
    # The old np.array(sorted(...)) produced a float64 array for an empty
    # selection, which cannot index an array; that crashed buffer_size=0 and
    # labels absent from ``y``.
    candidates = np.concatenate(
        [np.where(y == label)[0][:num] for label in range(10)]
    )

    buffer_x, buffer_y = [], []
    for i in range(1, task_id):
        # Every image gets the same seed-``i`` angle. Rotating only the
        # selected samples therefore gives the same pixels as rotating all
        # of ``x`` and indexing afterwards, at a fraction of the cost.
        buffer_x.append(rotate_dataset(x[candidates], seed=i))
        buffer_y.append(y[candidates])
    return np.concatenate(buffer_x, axis=0), np.concatenate(buffer_y)


if __name__ == '__main__':
    # Smoke test: build the task-2 rehearsal datasets and print the shape of
    # every batch. Swap BufferILDataset for ILDataset to check the
    # buffer-free variant.
    train_set = BufferILDataset(task_id=2, train=True)
    test_set = BufferILDataset(task_id=2, train=False)
    loader_kwargs = dict(batch_size=128, num_workers=12)
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

    for loader in (train_loader, test_loader):
        for x, y in tqdm.tqdm(loader):
            print(x.shape, y.shape)





