import numpy as np
import torch as t
from torch import nn
import random
import os
from utils.dataloader import MNISTMorphDataLoader, ConwayDataLoader, ConwayDataLoaderMutate, PathfinderDataLoader
from argparse import Namespace
# Place tensors on the GPU when torch can see one, otherwise on the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def get_uniform(a, b=None):
    """Resolve a config value into a (possibly random) integer count.

    Accepted call forms:

    * ``get_uniform(lo, hi)``: a random integer in ``[lo, hi)`` via
      ``random.randrange`` (upper bound exclusive).
    * ``get_uniform(n)`` with an integer ``n`` (Python ``int`` or NumPy
      integer scalar): returns ``n`` unchanged, with no randomness.
    * ``get_uniform(seq)``: ``random.randrange(*seq)``, so ``seq`` is
      ``(stop,)``, ``(start, stop)`` or ``(start, stop, step)``.
    """
    if b is not None:
        return random.randrange(a, b)
    # NumPy integer scalars are not ``int`` instances. Without this check they
    # would fall through to ``randrange(*a)`` and raise TypeError.
    if isinstance(a, (int, np.integer)):
        return a
    return random.randrange(*a)

def train_test_split(triples, targets, num_train_samples):
    """Cut a batch into a train prefix and a test suffix along axis 1.

    The first ``num_train_samples`` entries of axis 1 form the train split
    and the rest form the test split. ``triples`` is indexed with three axes
    and ``targets`` with two.

    Returns:
        ``(triples_train, targets_train, triples_test, targets_test)``.
    """
    cut = num_train_samples
    train_part = (triples[:, :cut, :], targets[:, :cut])
    test_part = (triples[:, cut:, :], targets[:, cut:])
    return (*train_part, *test_part)

class DataHandler:
    """Base class tracking the state of one data "run".

    A run starts from ``init_state`` (set by subclasses before calling
    ``DataHandler.init_run``), keeps a ``current_state`` and counts
    ``step``s. Subclasses must implement ``get_samples``.
    """

    def __init__(self, dataloader_args=None):
        # ``dataloader_args`` is accepted only for signature parity with the
        # subclasses; the base class owns no dataloader. ``None`` replaces a
        # shared mutable ``{}`` default.
        self.clear_run()

    def init_run(self, CONFIG):
        """Start a run: reset ``step`` to 0 and rewind to ``init_state``."""
        self.step = 0
        self.current_state = self.init_state

    def advance_run(self, CONFIG):
        """Advance the run by one step (the base class only bumps ``step``)."""
        self.step += 1

    def reset_run(self):
        """Rewind to ``init_state`` and step 0 without re-initialising."""
        self.current_state = self.init_state
        self.step = 0

    def clear_run(self):
        """Drop both the initial and the current run state."""
        self.init_state = None
        self.current_state = None

    def get_samples(self, *args, **kwargs):
        """Return samples for the current step; subclasses must override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_samples()"
        )

class MNISTDataHandler(DataHandler):
    """Data handler backed by an ``MNISTMorphDataLoader``.

    The run state is the value returned by ``dataloader.sample_indices``.
    Each ``advance_run`` re-samples it from the previous state, passing
    ``CONFIG.shift_min``/``CONFIG.shift_max`` through to the loader.
    """

    def __init__(self, dataloader_args=None):
        # Initialise the base-class run state so init_state/current_state
        # exist before the first init_run. ``None`` avoids a shared ``{}``
        # default argument.
        super().__init__()
        self.dataloader = MNISTMorphDataLoader(**(dataloader_args or {}))

    def init_run(self, CONFIG):
        """Sample ``CONFIG.n_worlds`` initial index sets and start at step 0."""
        ids = self.dataloader.sample_indices(CONFIG.n_worlds,
                                             num_relations=CONFIG.model_vars["model_inits"]["num_relations"],
                                             num_digits=CONFIG.num_digits)
        self.init_state = ids
        DataHandler.init_run(self, CONFIG)

    def advance_run(self, CONFIG):
        """Bump ``step`` and re-sample ``current_state`` from its previous value."""
        DataHandler.advance_run(self, CONFIG)
        self.current_state = self.dataloader.sample_indices(batch_size=CONFIG.n_worlds,
                                                            num_relations=CONFIG.model_vars["model_inits"]["num_relations"],
                                                            old_indices=self.current_state,
                                                            shift_min=CONFIG.shift_min,
                                                            shift_max=CONFIG.shift_max)

    def get_samples(self, CONFIG, init=False, sample_test=False):
        """Fetch a run of samples from the dataloader at ``current_state``.

        Args:
            CONFIG: run configuration (attributes read below).
            init: if True, request ``CONFIG.init_samples * num_relations``
                train samples, plus ``CONFIG.init_test_samples * num_relations``
                test samples when ``sample_test``. Otherwise request
                ``get_uniform(CONFIG.train_samples_per_step)`` train plus
                ``CONFIG.test_samples_per_step`` test samples.
            sample_test: if True, split the result along axis 1 with
                ``train_test_split`` at the train count.

        Returns:
            ``(triples, targets)`` moved to ``device``, or
            ``(triples_train, targets_train, triples_test, targets_test)``
            when ``sample_test`` is True.
        """
        num_relations = CONFIG.model_vars["model_inits"]["num_relations"]
        if init:
            num_train_samples = CONFIG.init_samples * num_relations
            run_length = num_train_samples
            if sample_test:
                run_length += CONFIG.init_test_samples * num_relations
        else:
            num_train_samples = get_uniform(CONFIG.train_samples_per_step)
            # NOTE(review): test samples are always requested on this path, so
            # with sample_test=False they are returned mixed in with the train
            # samples. Confirm this is intended.
            run_length = num_train_samples + CONFIG.test_samples_per_step
        _, triples, targets, _, _ = self.dataloader.get_run(CONFIG.n_worlds,
                                                            num_relations=num_relations,
                                                            num_digits=CONFIG.num_digits,
                                                            run_length=run_length,
                                                            indices=self.current_state)
        triples, targets = triples.to(device), targets.to(device)
        if sample_test:
            return train_test_split(triples, targets, num_train_samples)
        return triples, targets

    def get_images(self, time_step, window_size):
        """Return ``dataloader.get_images(current_state)``.

        ``time_step`` and ``window_size`` are ignored. They exist for interface
        parity with the other handlers.
        """
        return self.dataloader.get_images(self.current_state)


class ConwayDataHandler(DataHandler):
    """Data handler for worlds produced by a ``ConwayDataLoader``.

    ``init_run`` generates ``CONFIG.n_worlds`` worlds with
    ``run_length=CONFIG.world_length``. The run state is that world object, and
    samples are requested from it at ``offset=self.step``.
    """

    def __init__(self, dataloader_args=None):
        # Initialise the base-class run state so init_state/current_state
        # exist before the first init_run. ``None`` avoids a shared ``{}``
        # default argument.
        super().__init__()
        self.dataloader = ConwayDataLoader(**(dataloader_args or {}))

    def init_run(self, CONFIG):
        """Generate fresh worlds and start a new run at step 0."""
        worlds = self.dataloader.generate_worlds(CONFIG.n_worlds,
                                                 CONFIG.model_vars["model_inits"]["num_entities"],
                                                 run_length=CONFIG.world_length)
        self.init_state = worlds
        DataHandler.init_run(self, CONFIG)

    def get_samples(self, CONFIG, init=False, sample_test=False):
        """Build a per-step sample request and fetch it from the dataloader.

        Args:
            CONFIG: run configuration (attributes read below).
            init: if True, request ``get_uniform(CONFIG.init_samples)`` train
                samples for every step in ``[CONFIG.min_init, CONFIG.max_init]``.
                Otherwise request train/test counts for every step spanned by
                the train and test ranges. Each count is 0 outside its own
                range, and the test count is capped so that train + test never
                exceeds ``num_entities ** 2``.
            sample_test: only used with ``init``. It additionally requests
                ``CONFIG.init_test_samples`` test samples per step.

        Returns:
            Tuple of tensors from ``dataloader.get_samples``, each moved to
            ``device``.
        """
        samples_per_step = {}
        if init:
            for i in range(CONFIG.min_init, CONFIG.max_init + 1):
                num_train_samples = get_uniform(CONFIG.init_samples)
                if sample_test:
                    samples_per_step[i] = {"train": num_train_samples,
                                           "test": CONFIG.init_test_samples}
                else:
                    samples_per_step[i] = num_train_samples
        else:
            start = min(CONFIG.min_test, CONFIG.min_train)
            end = max(CONFIG.max_test, CONFIG.max_train) + 1
            # Upper bound on train + test samples for one step.
            max_samples = CONFIG.model_vars["model_inits"]["num_entities"] ** 2
            for i in range(start, end):
                # Keep the draw order (train, then test) so seeded runs reproduce.
                train_samples = get_uniform(CONFIG.train_samples_per_step) if CONFIG.min_train <= i <= CONFIG.max_train else 0
                test_samples = get_uniform(CONFIG.test_samples_per_step) if CONFIG.min_test <= i <= CONFIG.max_test else 0
                # Cap the test count to what is left after the train draw, and
                # never request a negative count if the train draw alone
                # exceeds the bound.
                test_samples = max(0, min(test_samples, max_samples - train_samples))
                samples_per_step[i] = {"train": train_samples,
                                       "test": test_samples}
        result = self.dataloader.get_samples(self.current_state,
                                             offset=self.step,
                                             samples_per_step=samples_per_step,
                                             grid_restriction=CONFIG.grid_restriction)
        return tuple(r.to(device) for r in result)

    def get_images(self, time_step, window_size):
        """Return ``current_state[time_step:time_step + window_size]`` with axes 0 and 1 swapped."""
        return self.current_state[time_step: time_step + window_size].permute([1, 0, 2, 3])


class ConwayDataHandlerMutate(DataHandler):
    """Data handler for mutating worlds produced by a ``ConwayDataLoaderMutate``.

    The run state is the ``(worlds, mutations)`` pair returned by
    ``dataloader.generate_worlds``.
    """

    def __init__(self, dataloader_args=None):
        # Initialise the base-class run state so init_state/current_state
        # exist before the first init_run. ``None`` avoids a shared ``{}``
        # default argument.
        super().__init__()
        self.dataloader = ConwayDataLoaderMutate(**(dataloader_args or {}))

    @staticmethod
    def _window(CONFIG):
        """Return the number of steps covered by the union span of the train and test ranges."""
        start = min(CONFIG.min_test, CONFIG.min_train)
        end = max(CONFIG.max_test, CONFIG.max_train) + 1
        return end - start

    def init_run(self, CONFIG):
        """Generate mutating worlds and start a new run at step 0.

        One mutation count per world step is drawn with
        ``get_uniform(CONFIG.mutation_counts)``.
        """
        mutation_counts = [get_uniform(CONFIG.mutation_counts) for _ in range(CONFIG.world_length)]
        worlds, mutations = self.dataloader.generate_worlds(CONFIG.n_worlds,
                                                            CONFIG.model_vars["model_inits"]["num_entities"],
                                                            run_length=CONFIG.world_length,
                                                            mutation_counts=mutation_counts,
                                                            mutation_step=CONFIG.min_train,
                                                            window=self._window(CONFIG))
        self.init_state = worlds, mutations
        DataHandler.init_run(self, CONFIG)

    def get_samples(self, CONFIG, **kwargs):
        """Fetch samples for the current step; extra kwargs are ignored (compatibility).

        Returns:
            Tuple of tensors from ``dataloader.get_samples``, each moved to
            ``device``.
        """
        result = self.dataloader.get_samples(self.current_state,
                                             offset=self.step,
                                             mutation_step=CONFIG.min_train,
                                             window=self._window(CONFIG),
                                             # presumably get_samples is first called
                                             # after one advance_run (step == 1) — confirm
                                             init=self.step == 1)
        return tuple(r.to(device) for r in result)

    def get_images(self, time_step, window_size):
        """Return ``worlds[time_step, :window_size]`` with its first two axes swapped."""
        return self.current_state[0][time_step, :window_size].permute([1, 0, 2, 3])


class PathfinderDataHandler(DataHandler):
    """Data handler for images served by a ``PathfinderDataLoader``.

    The run state is the ``(x, y, indices)`` triple returned by
    ``dataloader.get_worlds``. At each step, ``get_samples`` takes the next
    chunk of ``indices`` columns as training samples. It draws test samples
    from the columns before that chunk according to
    ``CONFIG.sampling_schedule``.
    """

    def __init__(self, dataloader_args=None):
        # Initialise the base-class run state so init_state/current_state
        # exist before the first init_run. ``None`` avoids a shared ``{}``
        # default argument.
        super().__init__()
        self.dataloader = PathfinderDataLoader(**(dataloader_args or {}))

    def init_run(self, CONFIG):
        """Fetch ``CONFIG.n_worlds`` worlds and start a new run at step 0."""
        x, y, indices = self.dataloader.get_worlds(CONFIG.n_worlds)
        self.init_state = x, y, indices
        DataHandler.init_run(self, CONFIG)

    def get_samples(self, CONFIG, **kwargs):
        """Return ``(train_triples, train_targets, test_triples, test_targets)``.

        Training samples come from ``indices[:, start:start + step_area]``,
        with ``step_area = grid_area // CONFIG.n_updater_steps`` and
        ``start = grid_area * (self.step - 1) // CONFIG.n_updater_steps``.
        Test samples come from columns before ``start``, chosen by
        ``CONFIG.sampling_schedule``:

        * ``"uniform"``: per row, up to ``step_area`` distinct random
          positions among the first ``start`` columns.
        * ``"complete"``: every column in ``[0, start)``.
        * ``"exponential"``: geometrically shrinking slices of earlier chunks,
          then ``CONFIG.tail_samples`` random single columns.

        Extra keyword arguments are accepted for interface compatibility and
        ignored.

        Raises:
            NotImplementedError: for an unrecognised ``CONFIG.sampling_schedule``.
        """
        x, _, indices = self.init_state
        batch_size = x.shape[0]
        grid_area = self.dataloader.grid_size ** 2
        step_area = grid_area // CONFIG.n_updater_steps
        start = grid_area * (self.step - 1) // CONFIG.n_updater_steps
        end = start + step_area
        # NOTE(review): when grid_area is not divisible by n_updater_steps,
        # ``start + step_area`` can fall short of the next step's ``start``,
        # which leaves untrained columns and drops the final ones.
        # get_seen_images uses ``grid_area * step // n_updater_steps`` as its
        # boundary, so the two can disagree. Confirm this is intended.
        train_triples, train_targets = self.dataloader.get_samples(x, indices[:, start:end])

        schedule = CONFIG.sampling_schedule
        if schedule == "uniform":
            test_indices = self._uniform_test_indices(indices, batch_size, start, step_area)
        elif schedule == "complete":
            test_indices = indices[:, :start]
        elif schedule == "exponential":
            test_indices = self._exponential_test_indices(CONFIG, indices, start, step_area)
        else:
            raise NotImplementedError("sampling schedule not recognized")
        test_triples, test_targets = self.dataloader.get_samples(x, test_indices)
        return train_triples, train_targets, test_triples, test_targets

    @staticmethod
    def _uniform_test_indices(indices, batch_size, start, step_area):
        """Per row, gather up to ``step_area`` distinct random entries from ``indices[:, :start]``."""
        # One independent permutation of range(start) per row, generated on the CPU.
        sub_indices = t.empty(batch_size, start, dtype=t.long)
        for row in range(batch_size):
            t.randperm(start, out=sub_indices[row])
        return t.gather(indices, 1, sub_indices[:, :step_area].to(device))

    def _exponential_test_indices(self, CONFIG, indices, start, step_area):
        """Collect test columns from earlier chunks with geometrically shrinking slices.

        For k = 1 .. self.step - 1, the chunk ``[start - k*step_area,
        start - (k-1)*step_area)`` contributes ``step_area // 2**k``
        contiguous columns beginning ``step_area // 2**k`` past the chunk
        start. Once that size reaches 0, ``CONFIG.tail_samples`` single
        columns are drawn uniformly from ``[0, start - (k-1)*step_area)``
        instead, and the walk stops. Returns an empty ``indices[:, :0]`` if
        nothing was collected.
        """
        slice_area = step_area
        slice_start = start
        chunks = []
        for _ in range(self.step - 1):
            slice_area //= 2
            slice_start -= step_area
            if slice_area == 0:
                # CONFIG.tail_samples is read only on this path, as before.
                for _ in range(CONFIG.tail_samples):
                    offset = random.randrange(slice_start + step_area)
                    chunks.append(indices[:, offset: offset + 1])
                break
            chunks.append(indices[:, slice_start + slice_area: slice_start + 2 * slice_area])
        return t.cat(chunks, dim=1) if chunks else indices[:, :0]

    def get_seen_images(self, CONFIG):
        """Return ``x / 255`` with not-yet-seen pixels overwritten by -1.

        The pixels at ``indices[:, end:]`` count as unseen, with
        ``end = grid_area * self.step // CONFIG.n_updater_steps``.
        ``put_`` addresses the tensor as if flattened, so row ``b``'s indices
        are shifted by ``b * grid_area``. This assumes each image holds exactly
        ``grid_area`` elements (single channel) — TODO confirm.
        """
        x, _, indices = self.init_state
        batch_size = x.shape[0]
        grid_area = self.dataloader.grid_size ** 2
        images = x.float() / 255
        end = grid_area * self.step // CONFIG.n_updater_steps
        mask = indices[:, end:] + t.arange(batch_size, device=device).unsqueeze(1) * grid_area
        images.put_(mask, -t.ones_like(mask, dtype=t.float32), accumulate=False)
        return images

    def get_labels(self, CONFIG, **kwargs):
        """Return the labels ``y`` stored in the run state."""
        _, y, _ = self.init_state
        return y

    def get_images(self, time_step, window_size):
        """Return ``x / 255``, replaced by ``1 - x / 255`` where ``y`` is 1.

        Computed as ``x*(1-y) + (1-x)*y`` with ``y`` reshaped to
        ``(batch, 1, 1, 1)``. This presumes ``y`` is 0/1 per sample and ``x``
        is 4-D — TODO confirm. ``time_step`` and ``window_size`` are ignored
        (interface parity with the other handlers).
        """
        x, y, _ = self.init_state
        x_view = x.float() / 255
        y_view = y.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        return x_view * (1 - y_view) + (1 - x_view) * y_view


if __name__ == "__main__":
    # Smoke check: building the handler constructs a PathfinderDataLoader
    # with its default arguments.
    _handler = PathfinderDataHandler()