import random
import numpy as np
import tensorflow as tf


def get_data_generator(args, name):
    """Build the data generator registered under ``name``.

    When ``args.data_random_seed`` is truthy, both Python's ``random`` module
    and NumPy's global RNG are seeded with it before the generator is
    created.  A falsy seed (e.g. ``0`` or ``None``) leaves both RNGs
    untouched.

    Args:
        args: Options object; must expose ``data_random_seed`` and is passed
            unchanged to the generator's constructor.
        name: One of ``'pair'``, ``'fmonth'`` or ``'color'``.

    Returns:
        An ``OverlapDataGenerator``, ``CompoundDataGenerator`` or
        ``ColorDataGenerator`` instance respectively.

    Raises:
        AssertionError: If ``name`` is not one of the known generator names.
    """
    seed = args.data_random_seed
    if seed:
        random.seed(seed)
        np.random.seed(seed)

    if name == 'pair':
        return OverlapDataGenerator(args)
    if name == 'fmonth':
        return CompoundDataGenerator(args)
    if name == 'color':
        return ColorDataGenerator(args)

    # Raised explicitly (rather than ``assert False``) so the check is not
    # stripped under ``python -O``, which would silently return None.
    raise AssertionError('unknown data generator name: {!r}'.format(name))


class OverlapDataGenerator(object):
    """Sample generator that overlays two labelled images into one input.

    The data returned by ``_load_data`` is grouped by label.  Every label
    pair ``(i, j)`` drawn from ``get_output_nodes()`` is assigned to either
    the training pairs or the test pairs according to ``is_train_label``.
    A sample is produced by picking one image per label of a pair and
    merging them with ``_combine``.

    Subclasses customise behaviour by overriding ``_load_data``,
    ``_combine``, ``_get_sample``, ``is_train_label``, ``get_input_shape``
    and ``get_output_nodes``.
    """

    def __init__(self, args):
        """Load the data, group it by label and split the label pairs.

        Args:
            args: Options object; stored as ``self.args`` and not otherwise
                read by this class.
        """
        self.args = args
        self.output_nodes = self.get_output_nodes()

        (x_train, y_train), (x_test, y_test) = self._load_data()

        self.train_samples = self._separate_data(x_train, y_train)
        self.test_samples = self._separate_data(x_test, y_test)

        self.train_data = self.get_data_group(self.train_samples)
        self.test_data = self.get_data_group(self.test_samples)

        # Partition every (first label, second label) combination into the
        # pairs seen during training and the pairs held out for testing.
        self.train_id_pairs = []
        self.test_id_pairs = []
        first_count, second_count = self.output_nodes[0], self.output_nodes[1]
        for i in range(first_count):
            for j in range(second_count):
                if self.is_train_label(i, j):
                    self.train_id_pairs.append([i, j])
                else:
                    self.test_id_pairs.append([i, j])

    def _load_data(self):
        """Return MNIST as ``(x_train, y_train), (x_test, y_test)``.

        Pixel values are divided by 255 and the images are cast to float32.
        """
        mnist = tf.keras.datasets.mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        x_train, x_test = x_train / 255.0, x_test / 255.0
        x_train = x_train.astype("float32")
        x_test = x_test.astype("float32")
        return (x_train, y_train), (x_test, y_test)

    def get_data_group(self, samples):
        """Group ``[x, y]`` samples into a list of per-label lists of ``x``.

        The result has ``self.output_nodes[0]`` buckets, indexed by label.
        """
        data = [[] for _ in range(self.output_nodes[0])]
        for x, y in samples:
            data[y].append(x)
        return data

    def is_train_label(self, y1, y2):
        """Return True when the label pair belongs to the training split.

        Here a pair is a training pair when ``y1 + y2`` is even.
        """
        return (y1 + y2) % 2 == 0

    def _separate_data(self, x_all, y_all):
        """Zip inputs and labels into a list of ``[x, int(y)]`` entries."""
        return [[x, int(y)] for x, y in zip(x_all, y_all)]

    def _one_hot(self, a):
        """Return the label unchanged.

        Despite the name, labels are passed through as plain indices; no
        one-hot vector is built.
        """
        return a

    def _combine(self, x1, x2):
        """Overlay ``x1`` with the transpose of ``x2``.

        Returns:
            A tuple ``(mix, x1, x2_t)`` where ``x2_t`` is ``x2`` transposed
            and ``mix`` is the element-wise mean of ``x1`` and ``x2_t``.
        """
        x2 = np.transpose(x2)
        return 0.5 * (x1 + x2), x1, x2

    def _get_sample(self, y1, y2, data):
        """Draw one random image per label and combine them.

        Args:
            y1: Label of the first component.
            y2: Label of the second component.
            data: Per-label image lists as built by ``get_data_group``.

        Returns:
            ``(x, x1, x2)``: the combined input and its two components,
            each with a trailing axis of size 1 appended.
        """
        # The two draws happen in this order so seeded runs stay reproducible.
        x1 = random.choice(data[y1])
        x2 = random.choice(data[y2])
        x, x1, x2 = self._combine(x1, x2)

        x = np.expand_dims(x, -1)
        x1 = np.expand_dims(x1, -1)
        x2 = np.expand_dims(x2, -1)
        return x, x1, x2

    def _get_samples(self, data, k, isTrain):
        """Draw ``k`` samples from ``data``.

        Args:
            data: Per-label image lists to draw images from.
            k: Number of samples.
            isTrain: Use the training label pairs when true, otherwise the
                held-out test label pairs.

        Returns:
            ``(x, [y1, y2], [x1, x2])`` where ``x`` holds the combined
            inputs, ``y1``/``y2`` the labels and ``x1``/``x2`` the individual
            components, each converted with ``np.asarray``.
        """
        id_pairs = self.train_id_pairs if isTrain else self.test_id_pairs
        ids = random.choices(id_pairs, k=k)
        x_list, y1_list, y2_list = [], [], []
        x1_list, x2_list = [], []
        for (y1, y2) in ids:
            x, x1, x2 = self._get_sample(y1, y2, data)
            y1_list.append(self._one_hot(y1))
            y2_list.append(self._one_hot(y2))
            x_list.append(x)
            x1_list.append(x1)
            x2_list.append(x2)
        x_list = np.asarray(x_list)
        y_list = [np.asarray(y1_list), np.asarray(y2_list)]
        z_list = [np.asarray(x1_list), np.asarray(x2_list)]
        return x_list, y_list, z_list

    def get_input_shape(self):
        """Return the shape of a single combined input."""
        return (28, 28, 1)

    def get_output_nodes(self):
        """Return the number of classes for the first and second label."""
        return [10, 10]

    def get_training_samples(self, k):
        """Return ``k`` training samples as ``(x, [y1, y2])``.

        Images come from the training split and label pairs from the
        training pairs; the per-component images are dropped.
        """
        return self._get_samples(self.train_data, k, isTrain=True)[:2]

    def get_eval_samples(self, k):
        """Return ``k`` samples of test-split images with training pairs."""
        return self._get_samples(self.test_data, k, isTrain=True)

    def get_test_samples(self, k):
        """Return ``k`` samples of test-split images with held-out pairs."""
        return self._get_samples(self.test_data, k, isTrain=False)


class CompoundDataGenerator(OverlapDataGenerator):
    """Generator whose inputs are two month names concatenated.

    Each month name is encoded letter by letter as a 27-wide one-hot row
    (26 letters plus one extra column).  Two encoded names are concatenated
    and padded to ``self.max_length`` rows; padding rows set only the last
    column.

    NOTE: the individual components returned by ``_combine`` are not
    padded, so they have as many rows as the month name has letters.
    """

    def __init__(self, args):
        """Initialise the base generator and draw a row permutation.

        ``self.perm`` is a shuffled list of the row indices
        ``0 .. max_length - 1``; it is not used elsewhere in this class.
        """
        super().__init__(args)
        length = self.get_input_shape()[0]
        self.perm = list(range(length))
        random.shuffle(self.perm)

    def _to_one_hot(self, month, alphabets):
        """Encode ``month`` as a list of one-hot rows of width ``alphabets``.

        A character's column is ``ord(c) - ord('a')``; characters whose
        column falls outside ``range(alphabets)`` yield an all-zero row.
        """
        sample = []
        for c in month:
            index = ord(c) - ord('a')
            sample.append(
                [1. if i == index else 0. for i in range(alphabets)])
        return sample

    def _get_months(self):
        """Return the month names and record the padded input length.

        Side effect: sets ``self.max_length`` to twice the longest name, so
        any two names concatenated always fit.
        """
        months = [
            'january',
            'february',
            'march',
            'april',
            'may',
            'june',
            'july',
            'august',
            'september',
            'october']
        self.max_length = max(len(x) for x in months) * 2
        return months

    def _load_data(self):
        """Return ``(x_train, y_train), (x_test, y_test)`` for the months.

        Both splits contain the same encoded month names; the label of a
        month is its position in the list returned by ``_get_months``.
        """
        months = self._get_months()
        alphabets = ord('z') - ord('a') + 2  # 1 additional dim.
        data = []
        for _ in range(2):
            x = [self._to_one_hot(month, alphabets) for month in months]
            y = list(range(len(months)))
            # Month names differ in length, so the encoded samples are
            # ragged.  NumPy >= 1.24 raises ValueError for ragged input
            # unless dtype=object is requested; with it we get a 1-D object
            # array whose elements are the per-month row lists.
            ragged = len({len(sample) for sample in x}) > 1
            x_arr = np.asarray(x, dtype=object if ragged else None)
            data.append((x_arr, np.asarray(y)))
        return data[0], data[1]

    def _combine(self, x1, x2):
        """Concatenate two encoded names and pad to ``self.max_length`` rows.

        Returns:
            ``(padded, x1, x2)`` where ``padded`` has ``self.max_length``
            rows and ``x1``/``x2`` are returned unchanged.
        """
        ret = np.concatenate((x1, x2))

        length, depth = ret.shape[0], ret.shape[1]
        patch = np.zeros((self.max_length - length, depth))
        # Padding rows are marked by a 1 in the last (extra) column.
        patch[:, -1] = 1
        ret = np.concatenate((ret, patch))
        return ret, x1, x2

    def get_input_shape(self):
        """Return the shape of a single padded input."""
        return (self.max_length, 27, 1)


class ColorDataGenerator(OverlapDataGenerator):
    """Generator that places one image into one of three color channels.

    The first label selects the image class and the second label selects
    the channel (0, 1 or 2) that carries the image; the other two channels
    are zero.
    """

    def get_input_shape(self):
        """Return the shape of a single three-channel input."""
        return (28, 28, 3)

    def get_output_nodes(self):
        """Return the class counts: 10 image labels and 3 channels."""
        return [10, 3]

    def is_train_label(self, y1, y2):
        """Return True unless the channel equals ``y1 % 3``."""
        return y1 % 3 != y2

    def _get_sample(self, y1, y2, data):
        """Draw a random image of class ``y1`` and put it in channel ``y2``.

        Args:
            y1: Image label used to index ``data``.
            y2: Channel index; must compare equal to 0, 1 or 2.
            data: Per-label image lists as built by ``get_data_group``.

        Returns:
            ``(x, x1, x1)``: the three-channel input followed by the
            single-channel source image twice.

        Raises:
            AssertionError: If ``y2`` is not 0, 1 or 2.
        """
        x1 = random.choice(data[y1])
        x1 = np.expand_dims(x1, -1)
        zero = x1 * 0

        # Raised explicitly (rather than ``assert False``) so the check is
        # not stripped under ``python -O``.
        if not any(y2 == channel for channel in range(3)):
            raise AssertionError(
                'channel index must be 0, 1 or 2, got {!r}'.format(y2))

        planes = [x1 if y2 == channel else zero for channel in range(3)]
        x = np.concatenate(planes, -1)
        return x, x1, x1
