import random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from itertools import chain
import numpy as np
import os
import torchvision.transforms as T
from torchvision.io import read_image


def zip_scalar(xs, a):
    """Pair every element of ``xs`` with the same value ``a``.

    Returns a new list ``[(x0, a), (x1, a), ...]`` in the iteration order of ``xs``.
    """
    return list(map(lambda item: (item, a), xs))


def make_class_dict(data):
    """Group ``(input, label)`` pairs by label.

    Args:
        data: iterable of ``(x, y)`` pairs, e.g. a PyTorch dataset or a list.

    Returns:
        dict mapping each label ``y`` to the list of inputs carrying that label.
        Labels appear in order of first occurrence; inputs keep their original order.
    """
    grouped = {}
    for sample, label in data:
        grouped.setdefault(label, []).append(sample)
    return grouped


def load_MiniImageNet(data_dir="data/ImageNet/data", train_per_class=500, image_size=(32, 32)):
    """Load MiniImageNet images from disk and split each class into train/test.

    ``data_dir`` must contain one sub-directory per class. Every file inside a
    class directory is read with ``read_image``, resized to ``image_size`` and
    cast to ``torch.float`` (pixel values are not rescaled).

    Class directories and the files inside them are visited in *sorted* order.
    ``os.listdir`` order is arbitrary, so without sorting the class ids and the
    train/test split would differ between machines and filesystems.
    Non-directory entries in ``data_dir`` are skipped.

    Args:
        data_dir: root directory with one sub-directory per class.
        train_per_class: the first ``train_per_class`` images of each class (in
            sorted filename order) go to the training split, the rest to testing.
        image_size: target ``(height, width)`` for ``T.Resize``.

    Returns:
        ``(train_data, test_data)``: lists of ``(image_tensor, class_id)`` pairs,
        where ``class_id`` is the index of the class directory in sorted order.
    """
    print("loading MiniImageNet")
    train_data = []
    test_data = []
    resize = T.Resize(size=image_size)
    class_dirs = sorted(
        entry for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    for y, class_dir in enumerate(class_dirs):
        print("loading class " + str(y))
        class_path = os.path.join(data_dir, class_dir)
        for i, image_file in enumerate(sorted(os.listdir(class_path))):
            # NOTE(review): read_image keeps the file's own channel count; if any
            # image is grayscale its tensor will have a different shape - confirm
            # the dataset is all RGB.
            x = resize(read_image(os.path.join(class_path, image_file))).type(torch.float)
            if i < train_per_class:
                train_data.append((x, y))
            else:
                test_data.append((x, y))

    return train_data, test_data


def _load_torchvision_split(dataset_cls, root="data", download=True):
    """Instantiate the train and test splits of a torchvision classification dataset.

    Both splits use ``ToTensor()`` as their transform.

    Args:
        dataset_cls: torchvision dataset class accepting ``root``, ``train``,
            ``download`` and ``transform`` keyword arguments (e.g. ``datasets.CIFAR10``).
        root: directory the dataset is stored in / downloaded to.
        download: forwarded to the dataset constructor.

    Returns:
        ``(training_data, testing_data)``.
    """
    training_data = dataset_cls(root=root, train=True, download=download, transform=ToTensor())
    testing_data = dataset_cls(root=root, train=False, download=download, transform=ToTensor())
    return training_data, testing_data


def load_CIFAR100(root="data", download=True):
    """Return the ``(train, test)`` CIFAR-100 datasets as tensors."""
    return _load_torchvision_split(datasets.CIFAR100, root, download)


def load_CIFAR10(root="data", download=True):
    """Return the ``(train, test)`` CIFAR-10 datasets as tensors."""
    return _load_torchvision_split(datasets.CIFAR10, root, download)


def load_MNIST(root="data", download=True):
    """Return the ``(train, test)`` MNIST datasets as tensors."""
    return _load_torchvision_split(datasets.MNIST, root, download)


class IncDataStreamBase:
    """Base class for incremental streams: bins train and test data by class label."""

    def __init__(self, train_data, test_data):
        """
        Args:
            train_data: iterable of ``(x, y)`` training pairs.
            test_data: iterable of ``(x, y)`` testing pairs.
        """
        # sort data into bins for each class
        self.class_data = make_class_dict(train_data)

        self.testClassData = make_class_dict(test_data)
        # snake_case alias matching SplitDataStream's attribute name; both names
        # refer to the same dict, so existing users of testClassData are unaffected.
        self.test_class_data = self.testClassData


class ClassicalIncDataStream(IncDataStreamBase):
    """Class-incremental stream that reveals one new random class per task.

    The stream starts with ``starting_num_of_classes`` randomly chosen classes.
    Once ``task_length`` samples have been drawn in the current task, the next
    call adds one more randomly chosen class, as long as unseen classes remain.
    """

    def __init__(self, train_data, test_data, task_length, starting_num_of_classes):
        super().__init__(train_data, test_data)
        self.task_length = task_length
        self._count = 0
        self.remaining_classes = list(self.class_data)
        self.available_classes = []
        for _ in range(starting_num_of_classes):
            self._reveal_random_class()

    def _reveal_random_class(self):
        """Move one random class from ``remaining_classes`` to ``available_classes``."""
        chosen = random.choice(self.remaining_classes)
        self.remaining_classes.remove(chosen)
        self.available_classes.append(chosen)

    def next(self):
        """Return a random training sample ``(x, y)`` from a random available class.

        At a task boundary, a new class is revealed first (if any remain) and the
        per-task counter restarts.
        """
        if self.remaining_classes and self._count == self.task_length:
            self._reveal_random_class()
            self._count = 0

        self._count += 1

        label = random.choice(self.available_classes)
        return random.choice(self.class_data[label]), label


class CIFAR100DataStream(IncDataStreamBase):
    """Incremental data stream over CIFAR-100, with a fine-to-coarse label helper."""

    # CIFAR-100 fine label -> coarse (super-class) label. Built once at class
    # creation instead of on every call.
    _FINE_TO_COARSE = {19: 11, 29: 15, 0: 4, 11: 14, 1: 1, 86: 5, 90: 18, 28: 3, 23: 10, 31: 11, 39: 5, 96: 17,
                       82: 2, 17: 9, 71: 10, 8: 18, 97: 8, 80: 16, 74: 16, 59: 17, 70: 2, 87: 5, 84: 6, 64: 12,
                       52: 17, 42: 8, 47: 17, 65: 16, 21: 11, 22: 5, 81: 19, 24: 7, 78: 15, 45: 13, 49: 10,
                       56: 17, 76: 9, 89: 19, 73: 1, 14: 7, 9: 3, 6: 7, 20: 6, 98: 14, 36: 16, 55: 0, 72: 0,
                       43: 8, 51: 4, 35: 14, 83: 4, 33: 10, 27: 15, 53: 4, 92: 2, 50: 16, 15: 11, 18: 7,
                       46: 14, 75: 12, 38: 11, 66: 12, 77: 13, 69: 19, 95: 0, 99: 13, 93: 15, 4: 0, 61: 3,
                       94: 6, 68: 9, 34: 12, 32: 1, 88: 8, 67: 1, 30: 0, 62: 2, 63: 12, 40: 5, 26: 13, 48: 18,
                       79: 13, 85: 19, 54: 2, 44: 15, 7: 7, 12: 9, 2: 14, 41: 19, 37: 9, 13: 18, 25: 6, 10: 3,
                       57: 4, 5: 6, 60: 10, 91: 1, 3: 8, 58: 18, 16: 3}

    @staticmethod
    def fine_to_coarse(fine_label):
        """Convert a CIFAR-100 fine label into its coarse (super-class) label.

        Callable both on the class and on an instance.

        Raises:
            KeyError: if ``fine_label`` is not in the mapping table.
        """
        # The previous body returned ``dict[table]`` (a typing alias), never the label.
        return CIFAR100DataStream._FINE_TO_COARSE[fine_label]

    def __init__(self):
        """Load CIFAR-100 and bin its train/test splits by class."""
        training_data, testing_data = load_CIFAR100()

        super().__init__(training_data, testing_data)


class SplitDataStream:
    """Abstract base for batched task streams.

    Bins the training and testing data by class label and stores the batch size.
    Subclasses must provide ``__iter__``.
    """

    def __init__(self, training_data, testing_data, batch_size):
        """
        Args:
            training_data: iterable of ``(x, y)`` training pairs.
            testing_data: iterable of ``(x, y)`` testing pairs.
            batch_size: batch size subclasses use when building loaders.
        """
        self.class_data = make_class_dict(training_data)
        self.test_class_data = make_class_dict(testing_data)
        self.batch_size = batch_size

    def __iter__(self):
        """Not implemented here; concrete streams must override it."""
        subclass_name = type(self).__name__
        raise NotImplementedError("Please implement an __iter__ method in " + subclass_name)


class SplitDatasetStream(SplitDataStream):
    """Split stream: classes are partitioned into disjoint, consecutive tasks.

    Classes, in sorted label order, are grouped into tasks of ``window_length``
    classes. Trailing classes that do not fill a whole task are dropped.

    At construction, each task's samples are precomputed and shuffled once, for
    both the training and the testing split. Each task takes at most
    ``data_per_class`` samples per class.

    Iterating the stream yields one ``DataLoader`` per task. The split comes from
    ``self.train``, toggled with ``eval()`` / ``training()``.
    """

    def __init__(self, training_data, testing_data, batch_size, train=True, window_length=5, data_per_class=500):
        """
        Args:
            training_data: iterable of ``(x, y)`` training pairs.
            testing_data: iterable of ``(x, y)`` testing pairs.
            batch_size: batch size of every yielded ``DataLoader``.
            train: iterate the training split if True, otherwise the testing split.
            window_length: number of classes per task.
            data_per_class: maximum number of samples taken from each class;
                ``None`` or a negative value means "use every sample".
        """
        super().__init__(training_data, testing_data, batch_size)

        self.train = train
        self.data_per_class = data_per_class

        self.window_length = window_length

        # Per-instance task caches. These used to be class attributes, so every
        # instance shared (and overwrote) the same two dicts.
        self.window_data = {}
        self.test_window_data = {}

        # Sorted labels: identical to range(num_classes) for 0..n-1 labels, but
        # does not break when labels are non-contiguous.
        self.classes = sorted(self.class_data)

        # calc training data sampling order
        self._calc_task_data(True)
        # calc testing data sampling order
        self._calc_task_data(False)

    def eval(self):
        """Iterate the testing split from now on."""
        self.train = False

    def training(self):
        """Iterate the training split from now on."""
        self.train = True

    def _num_tasks(self):
        """Number of complete tasks of ``window_length`` classes."""
        return len(self.classes) // self.window_length

    def _per_class_limit(self):
        """Slice bound applied to each class's samples; ``None`` means no limit.

        Negative values mean "no limit". Otherwise SplitMNIST's ``-1`` would slice
        ``[:-1]`` and silently drop the last sample of every class.
        """
        if self.data_per_class is None or self.data_per_class < 0:
            return None
        return self.data_per_class

    def _calc_task_data(self, train):
        """Precompute and shuffle the ``(x, y)`` list of every task for one split."""
        tasks = self.window_data if train else self.test_window_data
        class_data = self.class_data if train else self.test_class_data
        limit = self._per_class_limit()
        for task in range(self._num_tasks()):
            start = task * self.window_length
            task_classes = self.classes[start:start + self.window_length]
            new_task_data = list(chain.from_iterable(
                zip_scalar(class_data[y][:limit], y) for y in task_classes
            ))
            random.shuffle(new_task_data)
            tasks[task] = new_task_data

    def __iter__(self):
        tasks = self.window_data if self.train else self.test_window_data
        for task in range(self._num_tasks()):
            yield DataLoader(tasks[task], batch_size=self.batch_size)


class SplitCIFAR100(SplitDatasetStream):
    """Split stream over CIFAR-100 (5 classes per task by default)."""

    def __init__(self, batch_size, train=True, window_length=5):
        super().__init__(*load_CIFAR100(), batch_size, train, window_length)


class SplitCIFAR10(SplitDatasetStream):
    """Split stream over CIFAR-10 (2 classes per task by default)."""

    def __init__(self, batch_size, train=True, window_length=2):
        super().__init__(*load_CIFAR10(), batch_size, train, window_length)


class SplitMiniImageNet(SplitDatasetStream):
    """Split stream over MiniImageNet loaded from disk (2 classes per task by default)."""

    def __init__(self, batch_size, train=True, window_length=2):
        super().__init__(*load_MiniImageNet(), batch_size, train, window_length)


class SplitMNIST(SplitDatasetStream):
    """Split stream over MNIST (2 classes per task by default).

    Passes ``data_per_class=-1`` instead of the default cap of 500.
    """

    def __init__(self, batch_size, train=True, window_length=2):
        super().__init__(*load_MNIST(), batch_size, train, window_length, data_per_class=-1)


class DummySplitStream(SplitDatasetStream):
    """Class-agnostic control stream with as many tasks as the split stream.

    The whole training (and testing) set is shuffled once and cut into
    ``len(classes) // window_length`` consecutive chunks of equal size, one per
    task, so every task mixes all classes. Samples left over by the integer
    division are not yielded.
    """

    def __init__(self, training_data, testing_data, batch_size, train=True, window_length=5):
        super().__init__(training_data, testing_data, batch_size, train, window_length)

        self.training_data = list(training_data)
        self.testing_data = list(testing_data)

        # shuffle data
        random.shuffle(self.training_data)
        random.shuffle(self.testing_data)

        self.num_of_tasks = len(self.classes) // window_length
        self.num_train_samples_per_task = len(self.training_data) // self.num_of_tasks
        self.num_of_test_samples_per_task = len(self.testing_data) // self.num_of_tasks

    def __iter__(self):
        # Use the ``train`` flag toggled by eval()/training(). ``self.training`` is
        # the inherited *method*, which is always truthy, so testing on it meant
        # the test split was never used.
        use_train = self.train
        data = self.training_data if use_train else self.testing_data
        step = self.num_train_samples_per_task if use_train else self.num_of_test_samples_per_task
        for task in range(self.num_of_tasks):
            yield DataLoader(data[task * step:(task + 1) * step], batch_size=self.batch_size)


class DummySplitCIFAR100(DummySplitStream):
    """Dummy (class-agnostic) split stream over CIFAR-100."""

    def __init__(self, batch_size, train=True, window_length=5):
        super().__init__(*load_CIFAR100(), batch_size, train, window_length)


class DummySplitMiniImageNet(DummySplitStream):
    """Dummy (class-agnostic) split stream over MiniImageNet loaded from disk."""

    def __init__(self, batch_size, train=True, window_length=5):
        super().__init__(*load_MiniImageNet(), batch_size, train, window_length)


class DummySplitCIFAR10(DummySplitStream):
    """Dummy split stream over CIFAR-10, keeping at most ``data_per_class`` samples per class.

    The cap is applied to both the training and the testing split.
    """

    def __init__(self, batch_size, train=True, window_length=2, data_per_class=500):
        training_data, testing_data = load_CIFAR10()

        # Truncate each class of the training split, in label order of first appearance.
        self.class_data = make_class_dict(training_data)
        training_data = [pair for y, xs in self.class_data.items()
                         for pair in zip_scalar(xs[:data_per_class], y)]

        # Same truncation for the testing split.
        self.test_class_data = make_class_dict(testing_data)
        testing_data = [pair for y, xs in self.test_class_data.items()
                        for pair in zip_scalar(xs[:data_per_class], y)]

        super().__init__(training_data, testing_data, batch_size, train, window_length)


class WalkingDirStream(IncDataStreamBase):
    """Infinite stream whose class distribution takes a random walk on the simplex.

    Each step does three things:
    - draws a new class-probability vector ``beta`` from a Dirichlet distribution
      with concentration ``alpha * transition_func(previous_beta) + skew``;
    - draws a class index from ``beta``;
    - yields a random training sample ``(x, y)`` of that class.

    NOTE(review): the class index is a position in ``beta``, so this assumes
    labels are 0..len(beta)-1. Dirichlet concentrations must be positive, so
    confirm that ``skew``/``transition_func`` keep them that way.
    """

    def __init__(self, training_data, testing_data, initial_beta, alpha=1, transition_func=lambda x: x, skew=0):
        super().__init__(training_data, testing_data)
        self.beta = np.array(initial_beta)
        self.alpha = alpha
        self.transition_func = transition_func
        self.skew = skew

    def __iter__(self):
        while True:
            concentration = self.alpha * self.transition_func(self.beta) + self.skew
            self.beta = np.random.dirichlet(concentration)
            one_hot_draw = np.random.multinomial(n=1, pvals=self.beta)
            label = np.argmax(one_hot_draw)
            yield random.choice(self.class_data[label]), label


class WalkingDirMINST(WalkingDirStream):
    """Walking-Dirichlet stream over MNIST.

    The misspelled class name is kept for backward compatibility; prefer the
    ``WalkingDirMNIST`` alias below.
    """

    def __init__(self, initial_beta, alpha, transition_func, skew):
        training_data, testing_data = load_MNIST()
        super().__init__(training_data, testing_data, initial_beta, alpha, transition_func, skew)


# Correctly spelled alias of WalkingDirMINST.
WalkingDirMNIST = WalkingDirMINST


class WalkingDirCIFAR100(WalkingDirStream):
    """Walking-Dirichlet stream over CIFAR-100."""

    def __init__(self, initial_beta, alpha, transition_func, skew):
        super().__init__(*load_CIFAR100(), initial_beta, alpha, transition_func, skew)


class ShiftingWindowDataStream(SplitDataStream):
    """Stream of overlapping class windows that slide by one class per task.

    Classes are ordered by first appearance in the training data. With K classes
    and window length w, window j contains classes ``classes[j:j + w]``, giving
    ``K - w + 1`` windows.

    The class at position i of a window contributes the i-th of w equal chunks of
    its first ``data_per_class`` samples. A class therefore occupies a different
    chunk in each window it belongs to, and no sample lands in two windows.

    Window contents are precomputed and shuffled at construction.
    """

    def __init__(self, training_data, testing_data, batch_size, window_length=5, train=True, shuffle=False,
                 data_per_class=500):
        """
        Args:
            training_data: iterable of ``(x, y)`` training pairs.
            testing_data: iterable of ``(x, y)`` testing pairs.
            batch_size: batch size of every yielded ``DataLoader``.
            window_length: number of classes per window.
            train: iterate the training split if True (see ``eval()`` / ``train()``).
            shuffle: forwarded to each ``DataLoader``.
            data_per_class: slice bound on each class's samples before chunking.
        """
        super().__init__(training_data, testing_data, batch_size)
        self.training = train
        self.window_length = window_length
        self.shuffle = shuffle
        self.data_per_class = data_per_class

        # Per-instance window caches. These used to be class attributes, so every
        # instance shared (and overwrote) the same two dicts.
        self.window_data = {}
        self.test_window_data = {}

        self.classes = list(self.class_data.keys())
        # Shuffle each class's training samples so the chunks are random subsets.
        for xs in self.class_data.values():
            random.shuffle(xs)

        # must precompute window data so that randomness in each learning algo does not affect data order

        # calc training data sampling order
        self._calc_window_data(True)
        # calc testing data sampling order
        self._calc_window_data(False)

    def _num_windows(self):
        """Number of windows: ``len(classes) - window_length + 1``."""
        return len(self.classes) - self.window_length + 1

    # train states whether to use testing data or training data
    def _calc_window_data(self, train):
        window = self.window_data if train else self.test_window_data
        class_data = self.class_data if train else self.test_class_data
        for window_index in range(self._num_windows()):
            current_window_data = []
            for i, y in enumerate(self.classes[window_index:window_index + self.window_length]):
                data_chunk_len = len(class_data[y][:self.data_per_class]) // self.window_length
                current_window_data += zip_scalar(class_data[y][i * data_chunk_len:(i + 1) * data_chunk_len], y)
            # Shuffle once per window. It used to be re-shuffled after every class
            # was appended, which was wasted work.
            random.shuffle(current_window_data)
            window[window_index] = current_window_data

    def __iter__(self):
        window = self.window_data if self.training else self.test_window_data
        for window_index in range(self._num_windows()):
            yield DataLoader(window[window_index], batch_size=self.batch_size, shuffle=self.shuffle)

    def eval(self):
        """Iterate the testing split from now on."""
        self.training = False

    def train(self):
        """Iterate the training split from now on."""
        self.training = True


class CIFAR100ShiftWinStream(ShiftingWindowDataStream):
    """Shifting-window stream over CIFAR-100."""

    def __init__(self, batch_size, window_length=5, train=True, shuffle=False):
        super().__init__(*load_CIFAR100(), batch_size, window_length, train, shuffle)


class CIFAR10ShiftWinStream(ShiftingWindowDataStream):
    """Shifting-window stream over CIFAR-10."""

    def __init__(self, batch_size, window_length=5, train=True, shuffle=False):
        super().__init__(*load_CIFAR10(), batch_size, window_length, train, shuffle)


class MiniImageNetShiftWinStream(ShiftingWindowDataStream):
    """Shifting-window stream over MiniImageNet loaded from disk."""

    def __init__(self, batch_size, window_length=5, train=True, shuffle=False):
        super().__init__(*load_MiniImageNet(), batch_size, window_length, train, shuffle)


class DummyWindowDataStream(ShiftingWindowDataStream):
    """Class-agnostic control stream with as many tasks as the shifting-window stream.

    The full training (and testing) set is shuffled once and cut into
    ``len(classes) - window_length + 1`` consecutive chunks, one per task. The
    last chunk also takes the remainder left by the integer division.
    """

    def __init__(self, training_data, testing_data, batch_size,
                 training=True, window_length=5, train=True):
        super().__init__(training_data, testing_data, batch_size, window_length, train)
        # ``train`` only reaches the parent constructor. ``self.training`` is then
        # overwritten here, so ``training`` is what selects the split in __iter__.
        self.training = training
        self.num_of_tasks = len(self.classes) - self.window_length + 1
        self.train_task_size = len(training_data) // self.num_of_tasks
        self.test_task_size = len(testing_data) // self.num_of_tasks

        self.training_data = list(training_data)
        self.testing_data = list(testing_data)

        # shuffle data
        random.shuffle(self.training_data)
        random.shuffle(self.testing_data)

    def __iter__(self):
        data = self.training_data if self.training else self.testing_data
        step = self.train_task_size if self.training else self.test_task_size
        last_task = self.num_of_tasks - 1
        for i in range(self.num_of_tasks):
            # The final task runs to the end of the data, absorbing the remainder.
            end = None if i == last_task else (i + 1) * step
            yield DataLoader(data[i * step:end], batch_size=self.batch_size)

    def window_num_of_data(self, x, y, train=True):
        """Return a task index for sample ``x`` of class ``y``.

        The result is the first window containing ``y`` (its class index minus
        ``window_length - 1``, clamped at 0) plus the sample's position in
        ``class_data[y]`` divided by the task size. Samples are matched with
        ``torch.equal``. Returns ``None`` if ``x`` is not found.
        """
        class_data = self.class_data if train else self.test_class_data
        step = self.train_task_size if train else self.test_task_size
        first_task = max(0, self.classes.index(y) - self.window_length + 1)
        for i, a in enumerate(class_data[y]):
            if torch.equal(a, x):
                return first_task + i // step
        return None


class CIFAR100DummyShiftWinStream(DummyWindowDataStream):
    """Dummy (class-agnostic) shifting-window stream over CIFAR-100."""

    def __init__(self, batch_size, training=True, window_length=5, train=True):
        super().__init__(*load_CIFAR100(), batch_size, training, window_length, train)


class MiniImageNetDummyShiftWinStream(DummyWindowDataStream):
    """Dummy (class-agnostic) shifting-window stream over MiniImageNet loaded from disk."""

    def __init__(self, batch_size, training=True, window_length=5, train=True):
        super().__init__(*load_MiniImageNet(), batch_size, training, window_length, train)


class CIFAR10DummyShiftWinStream(DummyWindowDataStream):
    """Dummy shifting-window stream over CIFAR-10, keeping at most ``data_per_class`` samples per class.

    The cap is applied to both the training and the testing split.
    """

    def __init__(self, batch_size, training=True, window_length=5, train=True, data_per_class=500):
        training_data, testing_data = load_CIFAR10()

        # Truncate each class of the training split, in label order of first appearance.
        self.class_data = make_class_dict(training_data)
        training_data = [pair for y, xs in self.class_data.items()
                         for pair in zip_scalar(xs[:data_per_class], y)]

        # Same truncation for the testing split.
        self.test_class_data = make_class_dict(testing_data)
        testing_data = [pair for y, xs in self.test_class_data.items()
                        for pair in zip_scalar(xs[:data_per_class], y)]

        super().__init__(training_data, testing_data, batch_size, training, window_length, train)


if __name__ == '__main__':
    # Manual smoke test: load MiniImageNet from disk (slow) and report completion.
    train, test = load_MiniImageNet()
    print("done")
