import tensorflow as tf
import numpy as np
import random

from cfq_data import get_cfq_data


def get_data_generator(args):
    """Instantiate the data generator selected by ``args.merge_type``.

    Args:
        args: configuration object. ``args.merge_type`` picks the generator
            class and the whole object is forwarded to its constructor.

    Returns:
        A new instance of the matching generator class.

    Raises:
        ValueError: if ``args.merge_type`` is not one of 'colored',
            'disentangled', 'nico', 'celeba', 'review' or 'cfq'.
    """
    merge_type = args.merge_type
    if merge_type == 'colored':
        return ColoredDataGenerator(args)
    if merge_type == 'disentangled':
        return DisentangledDataGenerator(args)
    if merge_type == 'nico':
        return NICODataGenerator(args)
    if merge_type == 'celeba':
        return CelebADataGenerator(args)
    if merge_type == 'review':
        return ReviewDataGenerator(args)
    if merge_type == 'cfq':
        return CFQDataGenerator(args)
    raise ValueError('{0} is not a valid merge_type.'.format(merge_type))


def one_hot(a, output_nodes):
    """Return a list of ``output_nodes`` zeros with a 1 at index ``a``.

    Uses plain list indexing, so a negative ``a`` counts from the end and an
    index past the end raises IndexError.
    """
    vector = [0 for _ in range(output_nodes)]
    vector[a] = 1
    return vector


class RandomDataGenerator(object):
    """Base class that synthesises samples from a two-label grid.

    Every sample carries a label pair ``(y, y2)`` with both labels in
    ``range(self.output_nodes)``. ``_is_train_label`` splits the grid into
    train and test pairs. Subclasses provide the raw data
    (``_initialize_data``), the number of classes per label
    (``_get_output_nodes``), and how a label pair becomes an input
    (``_merge``).
    """

    def __init__(self, args):
        self.args = args
        self.output_nodes = self._get_output_nodes()

        self.train_samples, self.test_samples, self.input_shape = self._initialize_data()

        self.train_label_pairs, self.test_label_pairs = self._get_label_splits()

        # Label pairs that actually have held-out evaluation data. None means
        # "draw evaluation pairs from train_label_pairs"; it stays None for
        # subclasses whose _prepare_eval_data does not set it.
        self.eval_label_pairs = None
        self._prepare_eval_data(self.train_samples)

    def _initialize_data(self):
        """Return ``(train_samples, test_samples, input_shape)``."""
        raise NotImplementedError()

    def _get_output_nodes(self):
        """Return the number of classes for each of the two labels."""
        raise NotImplementedError()

    def _merge(self, y, y2, samples):
        """Build one input for label pair ``(y, y2)`` from ``samples``."""
        raise NotImplementedError()

    def _is_train_label(self, x, y):
        """Return True if label pair ``(x, y)`` belongs to the training split.

        Only the 'diagonal' split is supported. A pair is a training pair when
        the cyclic offset ``(y - x) mod output_nodes`` is less than half of
        ``output_nodes``.

        Raises:
            ValueError: if ``args.label_split`` is not 'diagonal'.
        """
        # Explicit check instead of assert, which is stripped under -O.
        if self.args.label_split != 'diagonal':
            raise ValueError(
                '{0} is not a valid label_split.'.format(self.args.label_split))
        diff = (y - x + self.output_nodes) % self.output_nodes
        return 2 * diff < self.output_nodes

    def _get_label_splits(self):
        """Partition all label pairs into ``(train_pairs, test_pairs)``.

        Pairs outside the training split always go to the test list. When
        ``args.use_all_training_data`` is set, they are also added to the
        train list.
        """
        train_label_pairs = []
        test_label_pairs = []
        for i in range(self.output_nodes):
            for j in range(self.output_nodes):
                if self._is_train_label(i, j):
                    train_label_pairs.append((i, j))
                else:
                    if self.args.use_all_training_data:
                        train_label_pairs.append((i, j))
                    test_label_pairs.append((i, j))
        return train_label_pairs, test_label_pairs

    def _get_samples(self, samples, k, is_train, label_pairs=None):
        """Draw ``k`` merged samples (with replacement) plus one-hot labels.

        Args:
            samples: data container passed through to ``_merge``.
            k: number of samples to draw.
            is_train: choose pairs from the train split if True, otherwise
                from the test split. Ignored when ``label_pairs`` is given.
            label_pairs: optional explicit list of ``(y, y2)`` pairs to draw
                from.

        Returns:
            ``(x, [y, y2])``. ``x`` is ``np.asarray`` of the merged inputs, and
            ``y``/``y2`` hold one one-hot row per sample.
        """
        if label_pairs is None:
            label_pairs = self.train_label_pairs if is_train else self.test_label_pairs
        label_list = random.choices(label_pairs, k=k)

        x_list, y_list, y2_list = [], [], []
        for y, y2 in label_list:
            x_list.append(self._merge(y, y2, samples))
            y_list.append(one_hot(y, self.output_nodes))
            y2_list.append(one_hot(y2, self.output_nodes))
        return np.asarray(x_list), [np.asarray(y_list), np.asarray(y2_list)]

    def _prepare_eval_data(self, data):
        """Hold out the last 10% of every training-split cell for evaluation.

        For each pair where ``_is_train_label`` is true, ``data[i][j]`` is
        shuffled in place. It is then replaced by its first 90%, and the
        remaining 10% becomes ``self.eval_data[i][j]``. Other cells get an
        empty list, so ``self.eval_label_pairs`` records which pairs were
        split and can safely be sampled for evaluation.
        """
        eval_data = [[[] for _ in range(self.output_nodes)] for _ in
                     range(self.output_nodes)]
        eval_pairs = []
        for i in range(self.output_nodes):
            for j in range(self.output_nodes):
                if self._is_train_label(i, j):
                    samples = data[i][j]
                    # NOTE(review): random.shuffle swaps items pairwise; if a
                    # cell is a multi-dimensional numpy array (rows are views)
                    # this duplicates rows -- confirm cells are lists.
                    random.shuffle(samples)
                    # A cell with a single sample keeps nothing for training.
                    threshold = (9 * len(samples)) // 10
                    data[i][j] = samples[:threshold]
                    eval_data[i][j] = samples[threshold:]
                    eval_pairs.append((i, j))
        self.eval_data = eval_data
        self.eval_label_pairs = eval_pairs

    def get_training_samples(self, k):
        """Draw ``k`` samples from the training split."""
        return self._get_samples(self.train_samples, k, is_train=True)

    def get_training_samples_for_evaluation(self, k):
        """Alias of ``get_training_samples``."""
        return self.get_training_samples(k)

    def get_eval_samples(self, k):
        """Draw ``k`` samples from the held-out evaluation data.

        Only pairs that received evaluation data are drawn. Previously, with
        ``use_all_training_data``, test pairs with empty eval cells could be
        chosen and sampling crashed.
        """
        # getattr: subclasses that bypass __init__ may lack the attribute and
        # may override _get_samples with the three-argument signature.
        eval_pairs = getattr(self, 'eval_label_pairs', None)
        if eval_pairs is None:
            return self._get_samples(self.eval_data, k, is_train=True)
        return self._get_samples(self.eval_data, k, is_train=True,
                                 label_pairs=eval_pairs)

    def get_test_samples(self, k, randomize=False):
        """Draw ``k`` samples from the test split.

        ``randomize`` is accepted for interface compatibility but is unused.
        """
        return self._get_samples(self.test_samples, k, is_train=False)

    def get_test_label_pairs(self):
        """Return the list of test-split label pairs."""
        return self.test_label_pairs

    def get_input_shape(self):
        """Return the input shape reported by ``_initialize_data``."""
        return self.input_shape

    def get_output_nodes(self):
        """Return the number of classes per label."""
        return self.output_nodes

    def get_vocab_size(self):
        """Return 0; text generators override this."""
        return 0


class ColoredDataGenerator(RandomDataGenerator):
    """Generator whose inputs are dataset images tinted by the second label.

    Label ``y`` is the image class. Label ``y2`` selects one of ten RGB tints
    that the image is multiplied by. Evaluation samples come from the
    dataset's own test split.
    """

    def __init__(self, args):
        palette = np.asarray([
            [0, 0, 1],
            [0, 1, 0],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
            [0.5, 0, 0],
            [0.5, 0.5, 0],
            [0.5, 0.5, 0.5],
        ])
        # Shape (10, 1, 3): each tint is a 1x3 row, so an (H, W, 1) image
        # matmul'd with it becomes (H, W, 3).
        self.colors = palette[:, np.newaxis, :]

        super().__init__(args)

    def _get_output_nodes(self):
        return 10

    def _get_data(self, data_name):
        """Load a Keras dataset as ``(train, test, per-image shape)``.

        Images without a channel axis get a trailing axis of size 1.
        """
        if data_name not in ('mnist', 'cifar10', 'fashion_mnist'):
            raise ValueError('{0} is not a valid data_name.'.format(data_name))
        dataset = getattr(tf.keras.datasets, data_name)

        (train_x, train_y), (test_x, test_y) = dataset.load_data()
        if train_x.ndim == 3:
            train_x = train_x[..., np.newaxis]
            test_x = test_x[..., np.newaxis]
        return [train_x, train_y], [test_x, test_y], train_x.shape[1:]

    def _initialize_data(self):
        train, test, shape = self._get_data(self.args.dataset)
        # _merge turns a single channel into RGB, so report 3 channels.
        if shape[-1] == 1:
            shape = tuple(shape[:-1]) + (3,)
        return self._prepare_data(train), self._prepare_data(test), shape

    def _prepare_data(self, data):
        """Scale images to [-0.5, 0.5] float32 and bucket them by class."""
        images, labels = data
        assert len(images) == len(labels)
        images = (images / 255.0 - 0.5).astype("float32")

        by_class = [[] for _ in range(self.output_nodes)]
        for image, label in zip(images, labels):
            by_class[int(label)].append(image)
        return by_class

    def _merge(self, y, y2, samples):
        # NOTE(review): cifar10 images already have 3 channels, so this matmul
        # against a (1, 3) tint would not line up -- confirm cifar10 is unused.
        return np.matmul(random.choice(samples[y]), self.colors[y2])

    def _prepare_eval_data(self, data):
        # Evaluate on the dataset's test split instead of a held-out slice.
        self.eval_data = self.test_samples


class DisentangledDataGenerator(RandomDataGenerator):
    """Generator whose inputs are the two labels themselves, one-hot encoded.

    ``_merge`` ignores the loaded data. Each input is the ``(2, 10)`` stack
    of ``one_hot(y)`` and ``one_hot(y2)``. The Keras dataset is still loaded
    and ``args.dataset`` validated, but only its label arrays are kept, grouped
    per class, to serve as train/test containers.
    """

    def _get_output_nodes(self):
        return 10

    def _get_data(self, data_name):
        """Load a Keras dataset, keeping only its label arrays.

        Returns:
            ``(train, test, shape)``. ``train``/``test`` are
            ``[labels, labels]`` and ``shape`` is ``[2, output_nodes]``.

        Raises:
            ValueError: for an unknown ``data_name``.
        """
        if data_name == 'mnist':
            dataset = tf.keras.datasets.mnist
        elif data_name == 'cifar10':
            dataset = tf.keras.datasets.cifar10
        elif data_name == 'fashion_mnist':
            dataset = tf.keras.datasets.fashion_mnist
        else:
            raise ValueError('{0} is not a valid data_name.'.format(data_name))

        # Images are discarded: the "samples" are the label arrays. The old
        # rank-3 expand_dims branch could never fire on label arrays.
        (_, y_train), (_, y_test) = dataset.load_data()
        shape = [2, self.output_nodes]
        train_samples = [y_train, y_train]
        test_samples = [y_test, y_test]
        return train_samples, test_samples, shape

    def _initialize_data(self):
        train, test, shape = self._get_data(self.args.dataset)
        train = self._prepare_data(train)
        test = self._prepare_data(test)
        return train, test, shape

    def _prepare_data(self, data):
        """Group float32 copies of ``data[0]`` by the integer labels in ``data[1]``."""
        x_all, y_all = data
        assert len(x_all) == len(y_all)
        x_all = x_all.astype("float32")

        data = [[] for _ in range(self.output_nodes)]
        # Flatten labels so (N, 1) label arrays (cifar10) yield scalars
        # instead of relying on int() of a one-element array.
        for x, y in zip(x_all, np.ravel(y_all)):
            data[int(y)].append(x)
        return data

    def _merge(self, y, y2, samples):
        # The input is the pair of one-hot labels; samples is not used.
        return np.asarray([one_hot(y, self.output_nodes),
                           one_hot(y2, self.output_nodes)])

    def _prepare_eval_data(self, data):
        # Evaluate on the test containers instead of a held-out slice.
        self.eval_data = self.test_samples


class NICODataGenerator(RandomDataGenerator):
    """Generator backed by a pre-extracted sample grid stored as ``.npy``.

    The loaded object is indexed as ``grid[y][y2]``, a sequence of array
    samples. The same grid is used as both the train and the test source;
    the label split decides which cells each side draws from.
    """

    _NUM_CLASSES = 5
    _NPY_PATH = '../../data/nico/track_1/feat.npy'

    def _initialize_data(self):
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # load files from a trusted source.
        grid = np.load(self._get_npy_name(), allow_pickle=True)
        return grid, grid, grid[0][0][0].shape

    def _get_output_nodes(self):
        return self._NUM_CLASSES

    def _get_npy_name(self):
        return self._NPY_PATH

    def _merge(self, y, y2, samples):
        picked = random.choice(samples[y][y2])
        # Presumably raw values are in [0, 255]; shift to about [-0.5, 0.5].
        return picked / 255.0 - 0.5


class CelebADataGenerator(NICODataGenerator):
    """NICO-style generator over a CelebA sample grid with 2 classes per label."""

    def _get_output_nodes(self):
        num_classes = 2
        return num_classes

    def _get_npy_name(self):
        npy_path = '../../data/celeba/feat.npy'
        return npy_path


class ReviewDataGenerator(RandomDataGenerator):
    """Generator over Amazon review token sequences.

    Label ``y`` is the star rating minus one. Label ``y2`` is the product
    category, i.e. the index of the review file in ``get_data``. Each sample
    is a token-id list starting with 1; 0 is used as padding in ``_merge``.
    """

    def _initialize_data(self):
        maxlen = 100
        # One extra slot for the leading start token (id 1).
        self.max_length = maxlen + 1
        data, self.vocab_size = self.get_data()
        shape = (self.max_length,)
        return data, data, shape

    def _get_output_nodes(self):
        return 5

    def get_data(self):
        """Load, tokenize and bucket up to 100000 shuffled lines per category.

        Lines that do not have exactly two tab-separated fields, or whose
        rating falls outside ``1..output_nodes``, are skipped.

        Returns:
            ``(data, vocab_size)``. ``data[rating - 1][category]`` is a list of
            token-id lists. ``vocab_size`` is the next unused id (0 and 1 are
            reserved).
        """
        id_list = [
            'Books_5',
            'Clothing_Shoes_and_Jewelry_5',
            'Home_and_Kitchen_5',
            'Electronics_5',
            'Movies_and_TV_5',
        ]

        id_map = {}
        max_id = 2  # 0 = padding, 1 = start token
        data = [[[] for _ in range(self.output_nodes)] for _ in
                range(self.output_nodes)]
        for category, file_id in enumerate(id_list):
            fn = '../../data/amazon/' + file_id + '.tsv'
            print('loading', fn)
            # Explicit encoding: review text may contain non-ASCII characters
            # and the platform default encoding is not guaranteed to be UTF-8.
            with open(fn, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            random.shuffle(lines)
            lines = lines[:100000]
            for line in lines:
                terms = line.strip().split('\t')
                if len(terms) != 2:
                    continue
                y = int(terms[0]) - 1
                # Out-of-range ratings would crash or (rating 0) silently wrap
                # to the last bucket via index -1; skip them like other
                # malformed lines.
                if not 0 <= y < self.output_nodes:
                    continue

                x = [1]
                for word in terms[1].split(' '):
                    if word not in id_map:
                        id_map[word] = max_id
                        max_id += 1
                    x.append(id_map[word])
                data[y][category].append(x)

        return data, max_id

    def _merge(self, y, y2, samples):
        # Truncate long reviews so every input has exactly max_length tokens;
        # without it the padding is empty and batches become ragged.
        x = random.choice(samples[y][y2])[:self.max_length]
        padded_x = np.array(x + ([0] * (self.max_length - len(x))))
        return padded_x

    def get_vocab_size(self):
        # Dedicated method for text data.
        return self.vocab_size


class CFQDataGenerator(RandomDataGenerator):
    """Generator for CFQ sequence pairs loaded from an MCD split file.

    There is no label grid: samples are ``(input, output)`` sequence pairs
    drawn directly from the split, and evaluation reuses the test split.
    """

    def __init__(self, args):
        # Does not call RandomDataGenerator.__init__; the attributes the base
        # class's public getters read are set directly here.
        self.args = args

        split_file = "../../data/cfq/splits/" + args.mcd_split + ".json"
        (X, Y, X_test, Y_test), (voc, act), _lengths, (max_input, max_output) = \
            get_cfq_data(split_file)

        self.train_samples = list(zip(X, Y))
        self.test_samples = list(zip(X_test, Y_test))
        self.input_shape = max_input
        self.output_nodes = max_output
        self.vocab_size = [len(voc), len(act)]

        self.eval_data = self.test_samples

    def _get_samples(self, samples, k, is_train):
        """Draw ``k`` pairs with replacement; ``is_train`` is unused.

        Returns ``(inputs, targets)``. Each input is the concatenation of a
        pair's input and output sequences; each target is the output sequence.
        """
        picked = random.choices(samples, k=k)
        inputs = np.asarray([source + target for source, target in picked])
        targets = np.asarray([target for _, target in picked])
        return inputs, targets

    def get_vocab_size(self):
        # Dedicated method for text data.
        return self.vocab_size

    def get_test_label_pairs(self):
        # No label grid, hence no test label pairs.
        return []
