import torch
from torchvision import datasets, transforms
import random
import torch.nn.functional as F
from collections import defaultdict
import numpy as np

# Default device for the tensors built in this module: CUDA when torch reports it available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

class DataLoader():
    """Minimal base class for the data loaders in this module.

    ``get_run`` is a placeholder that raises until a subclass overrides it.
    """

    def __init__(self):
        pass

    def get_run(self, *args, **kwargs):
        """Produce one run of data; not implemented in the base class.

        Raises:
            NotImplementedError: always.
        """
        # ``NotImplemented`` is the binary-operator sentinel, not an exception;
        # raising it would surface as a confusing TypeError instead.
        raise NotImplementedError


class MNISTDataset():
    """One MNIST split loaded fully into tensors, with per-label index tables.

    Attributes set by ``__init__``:
        dataset: the underlying ``datasets.MNIST`` object (downloaded to ``./mnist_data/`` if needed).
        images: float tensor with one flattened 784-pixel row per image, on ``device``.
        labels: Python list with the label of every image.
        labels_tensor: the same labels as a LongTensor on ``device``.
        len_indices_by_label: LongTensor of length 10 with the number of images per label.
        indices_by_label: 10-row LongTensor of image indices per label, right-padded with 0.
        permutations: LongTensor shaped like ``images``; row ``i`` is a random permutation
            of the pixel positions ``0..783`` for image ``i``.
    """

    def __init__(self, train=True, device=device):
        transform = transforms.Compose([
            transforms.ToTensor()])
        self.dataset = datasets.MNIST(root='./mnist_data/', train=train, transform=transform, download=True)

        # Walk the dataset once: every ``self.dataset[i]`` access re-runs the transform,
        # so images and labels are collected from the same pass.
        samples = [self.dataset[i] for i in range(self.dataset.data.size(0))]
        self.images = torch.cat([image for image, _ in samples]).view(-1, 784).to(device)
        self.labels = [label for _, label in samples]
        self.labels_tensor = torch.LongTensor(self.labels).to(device)
        del samples

        indices_by_label = [[] for _ in range(10)]
        for i, l in enumerate(self.labels):
            indices_by_label[l].append(i)
        self.len_indices_by_label = torch.LongTensor([len(members) for members in indices_by_label]).to(device)
        ibl = [torch.LongTensor(members).to(device) for members in indices_by_label]
        # Labels have different counts, so rows are padded with 0 to form a rectangular table;
        # ``len_indices_by_label`` tells how many entries of each row are real.
        self.indices_by_label = torch.nn.utils.rnn.pad_sequence(ibl, batch_first=True, padding_value=0)

        # ``zeros_like`` already allocates on the device of ``self.images``.
        self.permutations = torch.zeros_like(self.images, dtype=torch.long)
        for i in range(self.images.size(0)):
            torch.randperm(self.images.size(1), out=self.permutations[i])


class MNISTDataLoader():
    """Samples tuples of MNIST images and pixel-level (row, relation, column) queries from them.

    Both the train and the test split are loaded at construction time; ``self.train``
    selects which one the sampling methods read from.
    """

    def __init__(self, train=True, device=device):
        DataLoader.__init__(self)
        self.datasets = {True: MNISTDataset(True, device), False: MNISTDataset(False, device)}
        self.device = device
        self.train = train

    def sample_indices(self, batch_size=None, num_relations=2, num_digits=10, old_indices=None, shift_min=1, shift_max=3):
        """Sample a (batch_size, num_relations) tensor of image indices with consecutive digit labels.

        Within a row, relation ``r`` holds an image whose label is
        ``(first_digit + r) % num_digits``.

        Without ``old_indices`` the first digit of every row is drawn uniformly from
        ``range(num_digits)`` and ``batch_size`` must be given.

        With ``old_indices`` (a tensor or nested sequence of image indices), ``batch_size``
        and ``num_relations`` are taken from its shape and the first digit of each row is
        the label of the old row's first image shifted by a random amount in
        ``[shift_min, shift_max]`` (``shift_max`` is inclusive).

        Returns:
            LongTensor of image indices on ``self.device``.
        """
        dataset = self.datasets[self.train]
        if old_indices is not None:
            # Accept plain nested lists/tuples as well as tensors.
            old_indices = torch.as_tensor(old_indices, dtype=torch.long, device=self.device)
            batch_size = len(old_indices)
            num_relations = len(old_indices[0])
            old_first_label = dataset.labels_tensor[old_indices[:, 0]]
            shift = torch.randint(shift_min, shift_max + 1, size=(batch_size,), device=self.device)
            first_digit = (shift + old_first_label).unsqueeze(1)
        else:
            first_digit = torch.randint(num_digits, size=(batch_size, 1), device=self.device)
        # Use the loader's own device rather than the module-level default so that a
        # loader constructed with an explicit device keeps all tensors together.
        next_digit = (first_digit + torch.arange(num_relations, device=self.device)) % num_digits
        # Pick a uniformly random position inside the real (unpadded) part of each label's row.
        idx = (torch.rand((batch_size, num_relations), device=self.device) * dataset.len_indices_by_label[next_digit]).long()
        return dataset.indices_by_label[next_digit, idx]

    def get_run(self, batch_size, indices=None, num_relations=2, num_digits=10, run_length=784, held_out=0, shuffle=False, shuffle_entities=0):
        """Build pixel queries (triples) and their target intensities for a batch of image tuples.

        Each batch element is a tuple of ``num_relations`` images (one "universe"). A triple
        is ``(x, r, y)`` with ``x = pixel // 28``, ``r`` the relation (image slot) and
        ``y = pixel % 28``; relations are interleaved along the run, i.e. position ``k``
        of the run belongs to relation ``k % num_relations``.

        Args:
            batch_size: number of universes.
            indices: optional image indices, one row of ``num_relations`` entries per batch
                element; when omitted they are drawn with ``sample_indices``.
            num_relations: number of images per universe.
            num_digits: forwarded to ``sample_indices``.
            run_length: total number of triples per universe.
            held_out: with ``shuffle=True``, the last ``held_out`` positions of every image's
                fixed pixel permutation are never sampled, consistently across calls.
                It is not consulted when ``shuffle`` is false.
            shuffle: if false, pixels are taken in each image's stored permutation order
                from the start; if true, one random selection of permutation positions is
                drawn per relation and shared by all batch elements.
            shuffle_entities: falsy -> coordinates are returned as is; ``2`` -> the first and
                third triple components are relabelled with independent per-universe
                permutations of ``range(28)``; any other truthy value -> both components
                share one permutation per universe.

        Returns:
            ``(selected, triples, targets, y_perms, x_perms)`` where ``triples`` is
            ``(batch_size, run_length, 3)`` long, ``targets`` is ``(batch_size, run_length)``
            float, and the two permutation tensors are ``None`` unless ``shuffle_entities``
            is set.
        """
        triples = torch.zeros(batch_size, run_length, 3, dtype=torch.long, device=self.device)
        targets = torch.zeros(batch_size, run_length, device=self.device)
        dataset = self.datasets[self.train]
        images = dataset.images
        permutations = dataset.permutations
        if indices is not None:
            selected = indices
        else:
            selected = self.sample_indices(batch_size=batch_size, num_relations=num_relations, num_digits=num_digits)
        selected_transpose = [list(i) for i in zip(*selected)]
        for relation in range(num_relations):
            relation_index = torch.LongTensor(selected_transpose[relation]).to(self.device)
            # Number of run positions k < run_length with k % num_relations == relation.
            relation_length = (run_length + num_relations - 1 - relation) // num_relations
            if shuffle:  # same shuffle for all batch elements, should be random enough
                # Kept in its own name so the ``indices`` argument is not overwritten.
                pixel_order = torch.randperm(images.size(1) - held_out)[:relation_length]
                pixels = permutations[relation_index][:, pixel_order]
            else:
                pixels = permutations[relation_index, :relation_length]
            triples[:, relation::num_relations, 0] = pixels // 28
            triples[:, relation::num_relations, 2] = pixels % 28
            triples[:, relation::num_relations, 1] = relation
            # Flat offset of every sampled pixel inside the flattened image tensor.
            pixel_offset = relation_index.unsqueeze(1) * images.size(1) + pixels

            targets[:, relation::num_relations] = images.flatten()[pixel_offset.flatten()].view(batch_size, relation_length)
        if shuffle_entities:
            y_perms = torch.zeros(batch_size, 28, dtype=torch.long, device=self.device)
            for i in range(batch_size):
                torch.randperm(28, out=y_perms[i])
            if shuffle_entities == 2:
                x_perms = torch.zeros(batch_size, 28, dtype=torch.long, device=self.device)
                for i in range(batch_size):
                    torch.randperm(28, out=x_perms[i])
            else:
                x_perms = y_perms
            triples[:, :, 0] = torch.gather(y_perms, 1, triples[:, :, 0])
            triples[:, :, 2] = torch.gather(x_perms, 1, triples[:, :, 2])
            return selected, triples, targets, y_perms, x_perms
        return selected, triples, targets, None, None

    def get_images(self, ids):
        """Return the flattened images for ``ids`` stacked as (batch, relation, pixels)."""
        images = self.datasets[self.train].images
        return torch.stack([images[torch.LongTensor(x)] for x in zip(*ids)], dim=1)

    def get_labels(self, ids):
        """Return the labels for ``ids`` stacked as (batch, relation)."""
        labels = self.datasets[self.train].labels_tensor
        return torch.stack([labels[torch.LongTensor(x)] for x in zip(*ids)], dim=1)

class MNISTMorphDataLoader(MNISTDataLoader):
    """MNIST loader over binarised images, intended for sampling "morphs" between image tuples.

    On construction every image of both splits is thresholded to {0., 1.} and each
    dataset gets an ``on_set`` list holding, per image, the set of pixel positions that
    are on.
    """

    def __init__(self, threshold=.01, sample_warning=True, train=True, device=device):
        MNISTDataLoader.__init__(self, train, device)
        for dataset in self.datasets.values():
            dataset.images = dataset.images.ge(threshold).float()
            dataset.on_set = [set(dataset.images[i].nonzero().flatten().tolist()) for i in range(dataset.images.shape[0])]
        self.sample_warning = sample_warning

    def get_morphs(self, sample_counts, batch_size=None, num_relations=2, num_digits=10, shift_min=1, shift_max=3,
                   old_indices=None, new_indices=None,  train=True):
        """Sample pixels by how they change between an old and a new image tuple.

        ``sample_counts`` should be in the format ``(n00, n01, n10, n11)`` where, in ``nxy``,
        x is the value of the old pixel, y is the value of the new pixel, and ``nxy`` is
        the number of pixels desired of that type.

        Raises:
            NotImplementedError: always. The body below predates the current dataset
                layout and still has to be refactored before it can run.
        """
        # ``NotImplemented`` is not an exception; raising it produced a TypeError.
        raise NotImplementedError("get_morphs needs to be refactored for the current dataset access")
        # --- unreachable draft, kept for the pending refactor ---
        if old_indices is None:
            old_indices = self.sample_indices(batch_size, num_relations=num_relations, train=train)
        else:
            batch_size = len(old_indices)
            num_relations = len(old_indices[0])
        if new_indices is None:
            new_indices = self.sample_indices(batch_size, num_relations=num_relations, num_digits=num_digits,
                                              shift_min=shift_min, shift_max=shift_max, old_indices=old_indices, train=train)
        run_length = sum(sample_counts)
        triples = torch.zeros(batch_size, run_length, 3, dtype=torch.long, device=self.device)
        targets = torch.zeros(batch_size, run_length, device=self.device)
        images = self.train_images if train else self.test_images
        on_set = self.train_on_set if train else self.test_on_set
        all_set = set(range(784))
        for batch_index, (old_index, new_index) in enumerate(zip(old_indices, new_indices)):
            s11 = []
            s10 = []
            s01 = []
            s00 = []

            for r in range(num_relations):
                old_set = on_set[old_index[r]]
                new_set = on_set[new_index[r]]
                s11 += [(x, r) for x in old_set & new_set]
                s10 += [(x, r) for x in old_set - new_set]
                s01 += [(x, r) for x in new_set - old_set]
                s00 += [(x, r) for x in all_set - new_set - old_set]
            p11 = torch.LongTensor(s11)
            p10 = torch.LongTensor(s10)
            p01 = torch.LongTensor(s01)
            p00 = torch.LongTensor(s00)
            l11 = len(p11)
            l01 = len(p01)
            l10 = len(p10)
            l00 = len(p00)
            zero_cats = sum([x == 0 for x in (l00, l01, l10, l11)])
            if zero_cats > 0 and self.sample_warning:
                print("pictures with index{} and {} have no pixels of {} # category(ies)".format(old_index, new_index, zero_cats))
            if len(p11) == 0:
                p11 = p00
            if len(p00) == 0:
                p00 = p11
            if len(p01) == 0:
                p01 = p10
            if len(p10) == 0:
                p10 = p01
            if len(p11) == 0 or len(p01) == 0:
                print("pictures with index{} and {} have no pixels of multiple categories".format(old_index, new_index))
                raise IndexError
            sample_pools = [p00, p01, p10, p11]

            samples = torch.cat([pool[torch.randint(len(pool), (num,), dtype=torch.int64, device=self.device)]
                                        for num, pool in zip(sample_counts, sample_pools)])
            sample_relations = samples[:, 1]
            triples[batch_index, :, 1] = sample_relations
            pixels = samples[:, 0]
            triples[batch_index, :, 0] = pixels // 28
            triples[batch_index, :, 2] = pixels % 28
            sample_indices = [new_index[x] for x in sample_relations]
            targets[batch_index] = images[sample_indices, pixels]
        return new_indices, triples, targets




# 3x3 convolution kernel for the Game-of-Life step: each neighbour counts 1 and the
# centre cell counts 10, so one sum encodes both the cell's state and its neighbour count.
conway_weights = torch.tensor(
    [[1, 1, 1],
     [1, 10, 1],
     [1, 1, 1]],
    dtype=torch.uint8,
).reshape(1, 1, 3, 3)

def torch_init(batch_size, N, p=0.2, generator=None):
    """Return a random (batch_size, N, N) uint8 grid whose cells are 1 with probability ``p``.

    ``generator`` is an optional ``torch.Generator`` used for the draw; ``None`` uses
    torch's default random state.
    """
    # ``generator=None`` is torch.rand's own default, so it can be forwarded unconditionally.
    uniform = torch.rand((batch_size, N, N), generator=generator)
    alive = uniform < p
    return alive.type(torch.uint8)


def torch_step(grids):
    """Advance a batch of Game-of-Life grids by one generation with wrap-around borders.

    Args:
        grids: (batch, H, W) tensor of 0/1 cells.

    Returns:
        uint8 tensor of the same (batch, H, W) shape holding the next generation.
    """
    # Circular padding makes the 3x3 neighbourhood wrap around the grid edges.
    grids = F.pad(grids.unsqueeze(1), (1, 1, 1, 1), mode='circular')
    # Drop only the channel axis: a bare squeeze() would also remove the batch axis
    # (or a grid axis) whenever its size is 1 and change the output rank.
    newboard = F.conv2d(grids, conway_weights).squeeze(1)
    # With the centre weighted 10: 3 = dead cell with 3 neighbours (birth),
    # 12 / 13 = live cell with 2 / 3 neighbours (survival).
    newboard = (newboard==12) | (newboard==3) | (newboard==13)
    return newboard.type(torch.uint8)


def glider_init(batch_size, N):
    """Return a (batch_size, N, N) uint8 grid that is empty except for one five-cell glider.

    The live cells are (0, 0), (0, 1), (0, 2), (1, 0) and (2, 1) in every batch element.
    """
    board = torch.zeros(batch_size, N, N).type(torch.uint8)
    live_rows = [0, 0, 0, 1, 2]
    live_cols = [0, 1, 2, 0, 1]
    board[:, live_rows, live_cols] = 1
    return board

class ConwayDataLoader(DataLoader):
    """Generates Game-of-Life trajectories and samples (row, step, column) -> cell queries."""

    def __init__(self, batch_count=1000, gliders_only=False):
        """
        Args:
            batch_count: if truthy, training batches are drawn repeatedly from a pool of
                ``batch_count`` batches, each identified by the seed it is generated from.
            gliders_only: start every world from ``glider_init`` instead of a random grid.
        """
        self.batch_count = batch_count
        if self.batch_count:
            self.generator = torch.Generator()
            self.reset_epoch()
        self.gliders_only = gliders_only

    def reset_epoch(self):
        """Refill the pool of batch seeds ``0..batch_count-1`` in random order."""
        self.batches_left = list(range(self.batch_count))
        random.shuffle(self.batches_left)

    def set_batch_generator(self):
        """Seed ``self.generator`` with the next batch seed, starting a new epoch when the pool is empty."""
        if not self.batches_left:
            self.reset_epoch()
        self.generator.manual_seed(self.batches_left.pop())

    def generate_worlds(self, batch_size=100, grid_size=32, run_length=50, boring_threshold=10, train=True):
        """Simulate ``run_length`` steps from an initial grid.

        When ``train`` is true and a batch pool is configured, the initial random grid is
        drawn from the pool's next seed; otherwise torch's default random state is used.
        ``boring_threshold`` is accepted but not used by this implementation.

        Returns:
            Tensor of shape (run_length + 1, batch_size, grid_size, grid_size) on the
            module-level ``device``; index 0 is the initial grid.
        """
        if train and self.batch_count:
            self.set_batch_generator()
            generator = self.generator
        else:
            generator = None
        if self.gliders_only:
            grids = glider_init(batch_size, grid_size)
        else:
            grids = torch_init(batch_size, grid_size, generator=generator)

        res = [grids]
        for _ in range(run_length):
            grids = torch_step(grids)
            res.append(grids)
        return torch.stack(res).to(device)

    def get_samples(self, worlds, offset=0, samples_per_step=None, grid_restriction=False):
        """Sample cells from selected time steps of ``worlds``.

        Args:
            worlds: (steps, batch_size, y, x) tensor as returned by ``generate_worlds``.
            offset: added to every step key when indexing ``worlds``; the triple itself
                stores the un-offset key.
            samples_per_step: mapping from step to a sample count. ``{0: 100}`` samples 100
                cells at step 0; ``{0: {"train": 100, "test": 50}}`` samples 150 cells at
                step 0 and splits them into train and test. At least one entry is needed.
            grid_restriction: if truthy, arguments for ``torch.arange`` describing a range of
                coordinates; all cells of that square are used, in the same order for every
                batch element, instead of a random sample.

        Returns:
            ``(triples, targets)``, or ``(train_triples, train_targets, test_triples,
            test_targets)`` if any step requested test samples. Triples have shape
            (batch_size, sample_count, 3) and targets (batch_size, sample_count), where
            ``sample_count`` is aggregated over all steps.
        """
        # A shared ``{}`` default would be one dict object reused across calls.
        if samples_per_step is None:
            samples_per_step = {}
        train_triples = []
        train_targets = []
        test_triples = []
        test_targets = []
        steps, batch_size, y, x = worlds.shape
        grid_area = x * y

        for step, samples in samples_per_step.items():
            if isinstance(samples, int):
                train_samples = samples
                test_samples = 0
            else:
                train_samples = samples["train"]
                test_samples = samples["test"]
            total_samples = train_samples + test_samples
            if grid_restriction:
                side = torch.arange(*grid_restriction, device=device)
                # Flat indices (row * x + column) of every cell in the restricted square.
                indices = (side.unsqueeze(1) + side * x).flatten().repeat(batch_size, 1)
            else:
                # Independent random cell order per batch element, truncated to the request.
                indices = torch.empty(batch_size, grid_area, device=device, dtype=torch.long)
                for i in range(batch_size):
                    torch.randperm(grid_area, out=indices[i])
                indices = indices[:, :total_samples]
            triples = torch.zeros(batch_size, indices.shape[1], 3, dtype=torch.long, device=device)
            triples[:, :, 0] = indices // x
            triples[:, :, 2] = indices % x
            triples[:, :, 1] = step

            targets = torch.gather(worlds[step + offset].view(batch_size, -1), 1, indices).float()
            train_triples.append(triples[:, :train_samples])
            train_targets.append(targets[:, :train_samples])
            if test_samples:
                test_triples.append(triples[:, train_samples:])
                test_targets.append(targets[:, train_samples:])
        if test_targets:
            return torch.cat(train_triples, dim=1), torch.cat(train_targets, dim=1), torch.cat(test_triples, dim=1), torch.cat(test_targets, dim=1)
        return torch.cat(train_triples, dim=1), torch.cat(train_targets, dim=1)




def all_triples_tensor(batch_size, num_entities=28):
    """Return every (row, 0, column) triple of a ``num_entities`` x ``num_entities`` grid, per batch element.

    The result is a long tensor of shape (batch_size, num_entities ** 2, 3) in row-major
    cell order; the middle (relation) component is left at 0 for the caller to fill in.
    """
    flat = torch.arange(num_entities ** 2)
    rows = flat // num_entities
    cols = flat % num_entities
    single = torch.stack([rows, torch.zeros_like(flat), cols], dim=1)
    return torch.stack([single for _ in range(batch_size)], dim=0)


class ConwayDataLoaderMutate(ConwayDataLoader):
    """Game-of-Life loader whose worlds are short windows that get randomly mutated between windows."""

    def mutate(self, grids, mutation_count, mutation_step, generator=None):
        """Overwrite ``mutation_count`` random cells per grid with random 0/1 values.

        Args:
            grids: (batch_size, rows, columns) tensor; it is cloned, not modified.
            mutation_count: number of cells rewritten in every batch element.
            mutation_step: value stored in the relation slot of the returned triples.
            generator: optional ``torch.Generator`` for the random draws.

        Returns:
            ``(mutated_grids, (triples, targets, unchanged_indices))`` where ``triples`` is
            (batch_size, mutation_count, 3) holding (row, mutation_step, column),
            ``targets`` holds the new cell values as floats, and ``unchanged_indices``
            holds the flat indices of all cells that were not selected, on ``device``.
        """
        batch_size, height, width = grids.shape
        grid_area = height * width
        indices = torch.empty(batch_size, grid_area, dtype=torch.long)
        for i in range(batch_size):
            torch.randperm(grid_area, generator=generator, out=indices[i])
        unchanged_indices = indices[:, mutation_count:].to(device)
        indices = indices[:, :mutation_count]
        triples = torch.zeros(batch_size, mutation_count, 3, dtype=torch.long)
        # A flat row-major index decomposes by the row *width*; dividing by the number
        # of rows is only equivalent on square grids.
        triples[:, :, 0] = indices // width
        triples[:, :, 1] = mutation_step
        triples[:, :, 2] = indices % width
        targets = torch.randint(2, (batch_size, mutation_count),  generator=generator).float()
        grids = grids.detach().clone()
        # put_ addresses the tensor as if flattened, hence the per-batch offset.
        grids.put_(torch.arange(batch_size).unsqueeze(1) * grid_area + indices, targets.type(torch.uint8), accumulate=False)
        return grids, (triples, targets, unchanged_indices)

    def generate_worlds(self, batch_size=100, grid_size=32, run_length=50, mutation_counts=0, mutation_step=1, window=3, train=True):
        """Generate ``run_length + 1`` windows of ``window`` grids, each derived from the previous by a mutation.

        The first window is a plain simulation from a random grid. Every following window
        starts at step 1 of the previous one, replaces the grid at position
        ``mutation_step`` with a mutated copy of the previous window's grid at
        ``mutation_step + 1``, and simulates forward from the mutated grid.

        Args:
            mutation_counts: an int used for every window, or a sequence indexed per window.

        Returns:
            ``(grids, mutations)``: ``grids`` is the stacked windows moved to ``device``;
            ``mutations[k]`` is the mutation record (see ``mutate``) that produced window
            ``k``, with an empty list for the initial window.
        """
        if train and self.batch_count:
            self.set_batch_generator()
            generator = self.generator
        else:
            generator = None
        grids = torch_init(batch_size, grid_size, generator=generator)
        trajectory = [grids]
        for _ in range(window - 1):
            grids = torch_step(grids)
            trajectory.append(grids)
        trajectories = [trajectory]
        mutations = [[]]
        for i in range(run_length):
            trajectory = trajectories[-1][1:mutation_step + 1]
            # NOTE(review): for i == 0 a sequence is read at index -1 (its last element);
            # confirm whether ``mutation_counts[i]`` was intended.
            mutation_count = mutation_counts if isinstance(mutation_counts, int) else mutation_counts[i - 1]
            mutated_step, mutation = self.mutate(trajectories[-1][mutation_step + 1], mutation_count, mutation_step, generator=generator)
            mutations.append(mutation)
            trajectory.append(mutated_step)
            grids = mutated_step
            for _ in range(window - 1 - mutation_step):
                grids = torch_step(grids)
                trajectory.append(grids)
            trajectories.append(trajectory)
        return torch.stack([torch.stack(t) for t in trajectories]).to(device), mutations

    def get_samples(self, worlds, offset=0, mutation_step=1, window=3, init=False):
        """Split window ``offset`` of ``worlds`` into mutated cells and everything else.

        Args:
            worlds: the ``(grids, mutations)`` pair returned by ``generate_worlds``.
            offset: which window to read.
            mutation_step: position of the mutated grid inside the window.
            window: number of grids per window.
            init: if true, treat the whole grid at ``mutation_step`` as the "mutated" set
                (needed for window 0, which has no mutation record).

        Returns:
            ``(mutated_triples, mutated_targets, test_triples, test_targets)``. The test
            set holds every cell of all other steps in the window and, unless ``init``,
            also the cells at ``mutation_step`` that the mutation left untouched.
        """
        grids, mutations = worlds
        grids, mutations = grids[offset], mutations[offset]
        batch_size = grids.shape[1]
        num_entities = grids.shape[2]
        if init:
            mutated_triples = all_triples_tensor(batch_size, num_entities)
            mutated_triples[:, :, 1] = mutation_step
            mutated_targets = grids[mutation_step].view(batch_size, num_entities ** 2).float()
        else:
            mutated_triples, mutated_targets, unchanged_indices = mutations
            unchanged_triples = torch.zeros(batch_size, unchanged_indices.shape[1], 3, dtype=torch.long)
            unchanged_triples[:, :, 0] = unchanged_indices // num_entities
            unchanged_triples[:, :, 1] = mutation_step
            unchanged_triples[:, :, 2] = unchanged_indices % num_entities
            unchanged_targets = torch.gather(grids[mutation_step].view(batch_size, -1), 1, unchanged_indices).float()
        all_triples = torch.stack([all_triples_tensor(batch_size, num_entities)] * (window - 1), dim=1)
        all_targets = []
        for i in range(window):
            if i == mutation_step:
                continue
            # Slot index once the mutation step has been skipped.
            idx = i if i < mutation_step else i - 1
            all_triples[:, idx, :, 1] = i
            all_targets.append(grids[i].view(batch_size, num_entities ** 2).float())
        all_triples = all_triples.view(batch_size, -1, 3)
        if init:
            test_triples = all_triples
            test_targets = torch.cat(all_targets, dim=1)
        else:
            test_triples = torch.cat([all_triples, unchanged_triples], dim=1)
            test_targets = torch.cat(all_targets + [unchanged_targets], dim=1)
        return mutated_triples, mutated_targets, test_triples, test_targets

def test_morphs():
    """Manual check that ``get_morphs`` targets match the new images and the old pixel categories.

    Prints one pair of booleans per batch element. ``get_morphs`` currently raises
    unconditionally, so this check cannot run to completion until it is refactored.
    """
    loader = MNISTMorphDataLoader()
    batch_size = 5
    per_category = 10
    old_indices = loader.sample_indices(batch_size)
    new_indices, triples, targets = loader.get_morphs([per_category] * 4, old_indices=old_indices)
    # Rebuild the flat pixel position from the (row, column) components of each triple.
    flat_pixels = triples[:, :, 0] * 28 + triples[:, :, 2]
    old_images = loader.get_images(old_indices)
    new_images = loader.get_images(new_indices)
    for b in range(batch_size):
        relations = triples[b, :, 1]
        old_targets = old_images[b, relations, flat_pixels[b]]
        new_targets = new_images[b, relations, flat_pixels[b]]
        print((new_targets == targets[b]).all())
        # Sample order is n00, n01, n10, n11, so the old pixel is 0 for the first half and 1 after.
        old_targets_check = torch.LongTensor([0] * (2 * per_category) + [1] * (2 * per_category), device=device)
        print((old_targets == old_targets_check).all())


def test_mutate():
    """Manual check of ``ConwayDataLoaderMutate.generate_worlds``.

    First prints, for every window and every consecutive pair of grids inside it, how
    many cells per batch element differ from a plain ``torch_step`` of the earlier grid.
    Then prints the same difference between position ``j`` of window ``i`` stepped once
    and position ``j`` of window ``i + 1``.
    """
    c = ConwayDataLoaderMutate()
    grids, mutations = c.generate_worlds(10, 4, 5, 3, 1)
    grids = grids.to("cpu")
    for i in range(grids.shape[0]):
        grid = grids[i]
        for j in range(grids.shape[1]-1):
            print((torch_step(grid[j]).float()-grid[j+1].float()).abs().sum(dim=[1,2]))
    for j in range(grids.shape[1]):
        for i in range(grids.shape[0]-1):
            print(i,j,(torch_step(grids[i, j]).float()-grids[i+1, j].float()).abs().sum(dim=[1,2]))



# Directory prefix from which PathfinderDataLoader loads "pathfinder{grid_size}hard.npz".
PATHFINDER_PATH = "../pathfinderhard/"


class PathfinderDataLoader(DataLoader):
    """Serves batches from one split of a Pathfinder ``.npz`` archive as cell-level queries."""

    def __init__(self, grid_size=32, split="train", shuffle=True, device=device):
        """
        Args:
            grid_size: side length of the grids; also selects the archive file name.
            split: prefix of the ``<split>_x`` / ``<split>_y`` arrays to load.
            shuffle: draw random batches if true, otherwise walk the split in order.
            device: device on which returned tensors are placed.
        """
        DataLoader.__init__(self)
        self.grid_size = grid_size
        self.device = device
        self.split = split
        # Indexing the archive reads the arrays into memory, so the file handle can be
        # released as soon as both arrays have been fetched.
        with np.load(PATHFINDER_PATH + "pathfinder{}hard.npz".format(grid_size)) as dataset:
            self.data = {"x": dataset[split + "_x"], "y": dataset[split + "_y"]}
        self.shuffle = shuffle
        self.next_idx = 0

    def get_worlds(self, batch_size):
        """Return one batch as ``(x, y, indices)``.

        ``x`` and ``y`` are the selected rows of the split's arrays (``y`` cast to float)
        and ``indices`` is a (batch_size, grid_size ** 2) tensor holding an independent
        random permutation of the flat cell positions for every batch element. In
        sequential mode the read position advances and wraps around the split.
        """
        l = self.data["y"].shape[0]
        if self.shuffle:
            idx = np.random.choice(l, batch_size, replace=False)
        else:
            end = self.next_idx + batch_size
            idx = np.arange(self.next_idx, end) % l
            self.next_idx = end
        grid_area = self.grid_size ** 2
        indices = torch.empty(batch_size, grid_area, dtype=torch.long)
        for i in range(batch_size):
            torch.randperm(grid_area, out=indices[i])
        # Place everything on the loader's own device, matching get_samples.
        x = torch.from_numpy(self.data["x"][idx]).to(self.device)
        y = torch.from_numpy(self.data["y"][idx]).to(self.device).float()
        return x, y, indices.to(self.device)

    def get_samples(self, worlds, indices):
        """Turn flat cell ``indices`` into (row, 0, column) triples and their targets.

        Args:
            worlds: batch of grids; flattened per batch element before gathering.
            indices: (batch_size, k) flat cell positions.

        Returns:
            ``(triples, targets)`` with targets being the gathered cell values divided by 255.
        """
        batch_size = worlds.shape[0]
        triples = torch.zeros(batch_size, indices.shape[1], 3, dtype=torch.long, device=self.device)
        triples[:, :, 0] = indices // self.grid_size
        triples[:, :, 2] = indices % self.grid_size
        targets = torch.gather(worlds.view(batch_size, -1), 1, indices)
        targets = targets.float() / 255
        return triples, targets


def test_pathfinder():
    """Manual check: load one Pathfinder batch and print its triples and targets."""
    pdl = PathfinderDataLoader(32)
    # get_samples needs both the grids and the sampled cell indices from get_worlds.
    x, _, indices = pdl.get_worlds(256)
    ret = pdl.get_samples(x, indices)
    print(ret[0], ret[1])



if __name__ == "__main__":
    # Run the manual checks when the module is executed as a script, then stop.
    test_pathfinder()
    test_mutate()
    exit()