import numpy as np
import tensorflow.compat.v1 as tf

# Switch TensorFlow to v1 (graph/session) behaviour: the models in this file
# are built from tf.placeholder tensors and executed through tf.Session.
tf.disable_v2_behavior()


def get_model(name):
    """Construct the model variant registered under ``name``.

    Args:
        name: One of ``'proposed'``, ``'mixed_output'``, ``'direct_input'``
            or ``'normal'``.

    Returns:
        A newly constructed instance of the matching model class.

    Raises:
        ValueError: If ``name`` is not one of the known variants.
    """
    if name == 'proposed':
        return Model()
    if name == 'mixed_output':
        return MixedOutputModel()
    if name == 'direct_input':
        return DirectInputModel()
    if name == 'normal':
        return NormalModel()
    # A real exception (instead of ``assert False``) still fires when Python
    # runs with -O, so an unknown name can never silently produce ``None``.
    raise ValueError('unknown model name: %r' % (name,))


def clip_gradient_norms(gradients, max_norm):
    """Clip each gradient independently with ``tf.clip_by_norm``.

    Args:
        gradients: Iterable of gradient tensors, ``tf.IndexedSlices`` or
            ``None`` entries.
        max_norm: Clipping threshold handed to ``tf.clip_by_norm``.

    Returns:
        A list in the same order as ``gradients``. ``None`` entries are kept
        as ``None``; ``tf.IndexedSlices`` are rebuilt around their clipped
        values with the original indices and dense shape.
    """
    def _clip(gradient):
        # Entries without a gradient are passed through untouched.
        if gradient is None:
            return None
        if not isinstance(gradient, tf.IndexedSlices):
            return tf.clip_by_norm(gradient, max_norm)
        # Sparse gradient: only the values are clipped.
        limited_values = tf.clip_by_norm(gradient.values, max_norm)
        return tf.IndexedSlices(
            limited_values, gradient.indices, gradient.dense_shape)

    return [_clip(gradient) for gradient in gradients]


class AbstractModel(object):
    """Shared graph construction, training loop and evaluation code.

    Subclasses must implement ``_create_logits``. While doing so they are
    also expected to set ``self.target_attention``, because ``_evaluation``
    reads that attribute.
    """

    def initialize(self, args, configs):
        """Build the graph, the loss and the train op, then open a session.

        Args:
            args: Options object. This class reads ``remove_entreg``,
                ``max_gradient_norm`` and ``length`` from it.
            configs: Mapping. This class reads ``'max_command_length'`` and
                ``'max_world_length'`` from it.
        """
        self.args = args
        self.configs = configs

        # create model
        self._create_model()

        # loss (needs self.logits from the step above)
        self._evaluation()

        # optimizer (needs self.loss from the step above)
        self._get_optimizer()

        # initializer: created last so it covers every variable in the graph
        init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init)

    def _create_model(self):
        """Create the placeholders and masks, then build ``self.logits``."""
        # place holders
        self.x_command = tf.placeholder(tf.int64, shape=(
            None, self.configs['max_command_length']))
        self.x_agent = tf.placeholder(tf.float32, shape=(None, 3))
        self.x_position = tf.placeholder(tf.float32, shape=(
            None, self.configs['max_world_length'], 2))
        self.x_world = tf.placeholder(tf.float32, shape=(
            None, self.configs['max_world_length'], 11))
        self.x_start = tf.placeholder(tf.float32, shape=(None,))
        self.x_command_length = tf.placeholder(tf.int64, shape=(None,))
        self.x_world_length = tf.placeholder(tf.int64, shape=(None,))
        self.y = tf.placeholder(tf.int64, shape=(None, 3))
        # Fed separately by _feed(); not part of a data batch.
        self.noise_stddev = tf.placeholder(tf.float32, shape=())

        self.x = [self.x_command, self.x_agent, self.x_position, self.x_world,
                  self.x_start, self.x_command_length, self.x_world_length]
        # A batch must list its arrays in exactly this order (see _feed).
        self.place_holders = self.x + [self.y]

        # masks: 1.0 for the first `length` positions of a row, 0.0 after
        self.command_mask = tf.sequence_mask(
            self.x_command_length, maxlen=self.configs['max_command_length'],
            dtype=tf.float32)
        self.world_mask = tf.sequence_mask(
            self.x_world_length, maxlen=self.configs['max_world_length'],
            dtype=tf.float32)

        # model design
        self.f_size = 32
        self.p_size = 32
        # Regularization terms collected by subclasses while building logits.
        self.reg_list = []
        self.logits = self._create_logits()

    def _create_logits(self):
        """Build and return the logits tensor; subclasses must override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # A real exception (not ``assert False``) survives ``python -O``.
        raise NotImplementedError(
            'subclasses of AbstractModel must implement _create_logits')

    def _evaluation(self):
        """Define the loss and the accuracy tensors on top of the logits."""
        # loss
        self.loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.y,
                                                           logits=self.logits))
        if not self.args.remove_entreg:
            self.loss += sum(self.reg_list)

        # acc
        self.prediction = tf.argmax(self.logits, -1)
        self.equality = tf.cast(
            tf.equal(self.y, self.prediction), tf.float32)
        # Accuracy of each of the 3 label columns, averaged over the batch.
        self.node_acc = tf.reduce_mean(self.equality, 0)
        # A sample counts as correct only if all of its columns are correct.
        self.sample_equality = tf.reduce_min(self.equality, -1)
        self.accuracy = tf.reduce_mean(self.sample_equality)

        # target attention accuracy: share of samples whose attention peaks
        # at index 0 (presumably the target is stored first in each world;
        # confirm against the data loader).
        target_prediction = tf.argmax(self.target_attention, -1)
        target_equality = tf.cast(
            tf.equal(tf.constant(0, dtype=tf.int64), target_prediction),
            tf.float32)
        self.target_acc = tf.reduce_mean(target_equality)

    def _get_optimizer(self):
        """Create the Adam train op and store it in ``self.optimizer``.

        The learning rate starts at 0.001 and is multiplied by 0.96 every
        100 steps (staircase decay). If ``args.max_gradient_norm`` is
        positive, every gradient is clipped with ``clip_gradient_norms``
        before being applied.
        """
        learning_rate = 0.001
        global_step = tf.Variable(0, trainable=False)
        starter_learning_rate = learning_rate
        decay_steps = 100
        decay_base = 0.96
        learning_rate = tf.train.exponential_decay(
            starter_learning_rate,
            global_step,
            decay_steps,
            decay_base,
            staircase=True)

        optimizer = tf.train.AdamOptimizer(learning_rate)
        if self.args.max_gradient_norm > 0:
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients = clip_gradient_norms(
                gradients, self.args.max_gradient_norm)

            self.optimizer = optimizer.apply_gradients(
                zip(clipped_gradients, params), global_step=global_step)
        else:
            self.optimizer = optimizer.minimize(
                self.loss, global_step=global_step)

    def _feed(self, batch, isTrain):
        """Map a batch onto the placeholders and set the noise level.

        Args:
            batch: Sequence of arrays ordered like ``self.place_holders``.
            isTrain: When true the noise stddev is 0.1, otherwise 0.

        Returns:
            A feed dict suitable for ``Session.run``.
        """
        feed = {a: b for a, b in zip(self.place_holders, batch)}
        if isTrain:
            feed[self.noise_stddev] = 0.1
        else:
            feed[self.noise_stddev] = 0
        return feed

    def training(self, dl, epochs=10000, batch_size=256, log_steps=1000,
                 eval_steps=10000):
        """Run the training loop, printing running averages as it goes.

        Args:
            dl: Data loader providing ``get_train_data(batch_size)`` plus the
                methods used by ``test``.
            epochs: Number of training steps; each step uses one mini-batch.
            batch_size: Number of samples requested per step.
            log_steps: Print the averaged metrics every this many steps.
            eval_steps: Call ``test`` every this many steps.
        """
        fetch = [self.optimizer, self.loss, self.accuracy, self.node_acc]
        # Running sums of [loss, node_acc (3 values), accuracy].
        avg = np.zeros(5)
        for i in range(epochs):
            batch = dl.get_train_data(batch_size)
            assert len(self.place_holders) == len(batch)
            feed = self._feed(batch, True)
            _, loss, acc, node_acc = self.sess.run(fetch, feed_dict=feed)
            avg += np.array([loss, *node_acc, acc])
            if (i + 1) % log_steps == 0:
                avg = avg / log_steps
                print(i + 1, *avg)
                avg = 0 * avg
            if (i + 1) % eval_steps == 0:
                print('Evaluation at training step', i + 1)
                self.test(dl)

    def _target_sentence_accuracy(self, sent_results, lengths):
        """Compute sentence accuracy bucketed by sentence length.

        Lengths are clamped into the range 14..18, so the first and last
        buckets also collect everything shorter or longer respectively.

        Args:
            sent_results: Per-sentence outcome; a value equal to 1 is a hit.
            lengths: Per-sentence length, parallel to ``sent_results``.

        Returns:
            A list with one accuracy per bucket (14, 15, 16, 17, 18), using
            ``-1.0`` for buckets that received no sentence.
        """
        assert len(sent_results) == len(lengths)
        keys = [14, 15, 16, 17, 18]
        totals = {x: 0 for x in keys}
        hits = {x: 0 for x in keys}
        for outcome, length in zip(sent_results, lengths):
            key = max(min(length, keys[-1]), keys[0])
            totals[key] += 1
            if outcome == 1:
                hits[key] += 1
        ret = []
        for key in keys:
            if totals[key] == 0:
                ret.append(-1.0)
            else:
                ret.append(hits[key] / totals[key])
        return ret

    def _sentence_accuracy(self, results, lengths):
        """Aggregate per-sample outcomes into per-sentence outcomes.

        ``results`` is consumed in consecutive runs of ``lengths[i]`` items;
        a sentence is a hit only if the minimum of its run is 1.

        Args:
            results: Flat sequence of per-sample outcomes (1 = correct).
            lengths: Number of samples belonging to each sentence; must sum
                to ``len(results)``.

        Returns:
            ``([accuracy, hits, sentences], per_sentence_outcomes)``.
        """
        assert sum(lengths) == len(results)
        sent_results = []
        offset = 0
        for length in lengths:
            sent_results.append(min(results[offset:offset + length]))
            offset += length
        sent_hits = int(sum(sent_results))
        sent_samples = len(sent_results)
        sent_ret = [sent_hits / sent_samples, sent_hits, sent_samples]
        return sent_ret, sent_results

    def test(self, dl, batch_size=10000):
        """Evaluate every non-``'train'`` split of ``dl`` and print metrics.

        Args:
            dl: Data loader providing ``get_names()`` and
                ``get_test_data(name)``; the latter returns
                ``(batch, lengths)``.
            batch_size: Maximum number of samples per ``Session.run`` call.
        """
        fetch = [self.loss, self.accuracy, self.node_acc, self.target_acc,
                 self.sample_equality]
        for name in dl.get_names():
            if name == 'train':
                continue
            batch, lengths = dl.get_test_data(name)
            assert len(self.place_holders) == len(batch)
            n_samples = len(batch[0])

            # Sample-weighted sums of
            # [loss, node_acc (3 values), target_acc, accuracy].
            numbers = np.zeros(6)
            results = []
            offset = 0
            while offset < n_samples:
                last = min(n_samples, offset + batch_size)
                part = [a[offset: last] for a in batch]
                feed = self._feed(part, False)
                loss, acc, d_acc, t_acc, result = self.sess.run(fetch,
                                                                feed_dict=feed)
                # Weight by chunk size so the last, shorter chunk counts
                # proportionally in the final average.
                numbers += len(part[0]) * np.array([loss, *d_acc, t_acc, acc])
                results.extend(result)
                offset = last
            numbers = numbers / n_samples

            sentence_acc, sent_acc = self._sentence_accuracy(results, lengths)
            if self.args.length:
                sents = self._target_sentence_accuracy(sent_acc, lengths)
                print(name, sents, *numbers, *sentence_acc)
            else:
                print(name, *numbers, *sentence_acc)


class Model(AbstractModel):
    """Model returned by ``get_model('proposed')``.

    A grounding network attends over the world entries to obtain a target
    position; three prediction networks then turn that position, the start
    flag and a command representation into logits.
    """

    # Entropy regularization
    def _entropy_regularization(self, rep):
        """Return ``(rep, reg)`` where ``reg`` is the mean square of ``rep``.

        Unless ``args.remove_entreg`` is set, Gaussian noise scaled by the
        ``noise_stddev`` placeholder is added to the returned representation.
        ``reg`` is always computed from the noise-free input.
        """
        reg = tf.reduce_mean(rep ** 2)
        if not self.args.remove_entreg:
            rep += self.noise_stddev * tf.random_normal(tf.shape(rep))
        return rep, reg

    def _masked_entropy_regularization(self, rep, mask, length):
        """Sequence-aware variant of ``_entropy_regularization``.

        The squared activations are averaged over the last axis, masked,
        summed over the sequence axis, divided by ``length`` and finally
        averaged over the remaining (batch) axis.

        Args:
            rep: Representation whose second-to-last axis is the sequence.
            mask: Float mask over the sequence axis (1.0 = valid entry).
            length: Number of valid entries per row.

        Returns:
            ``(rep, reg)``; ``rep`` carries added noise unless
            ``args.remove_entreg`` is set.
        """
        reg = tf.reduce_mean(rep ** 2, -1)
        reg = reg * mask
        reg = tf.reduce_sum(reg, -1)
        reg = reg / tf.cast(length, tf.float32)
        reg = tf.reduce_mean(reg)
        if not self.args.remove_entreg:
            rep += self.noise_stddev * tf.random_normal(tf.shape(rep))
        return rep, reg

    # Attention mechanism
    def _get_scores(self, query, keys):
        """Return the dot product of ``query`` with every row of ``keys``."""
        query = tf.expand_dims(query, -1)
        scores = tf.matmul(keys, query)
        scores = tf.squeeze(scores, -1)
        return scores

    def _masked_softmax(self, scores, mask):
        """Softmax over the last axis restricted to positions where mask is 1.

        Args:
            scores: Float scores; the last axis is the one normalized.
            mask: Float 0/1 mask with the same shape as ``scores``.

        Returns:
            Attention weights that are zero at masked positions and sum to
            one over the valid positions of each row.
        """
        # Move masked positions far below any real score so they neither
        # win the max below nor overflow inside exp.
        masked_scores = scores - (1.0 - mask) * 1e9
        # Softmax is invariant to a per-row shift; subtracting the row
        # maximum keeps exp() within range for large scores. The shift is a
        # constant as far as the gradient is concerned.
        shift = tf.stop_gradient(
            tf.reduce_max(masked_scores, -1, keepdims=True))
        ex = tf.exp(masked_scores - shift) * mask
        attention = ex / tf.reduce_sum(ex, -1, keepdims=True)
        return attention

    def _masked_attention(self, query, keys, mask):
        """Score ``keys`` against ``query`` and normalize under ``mask``."""
        scores = self._get_scores(query, keys)
        return self._masked_softmax(scores, mask)

    def _apply_attention(self, attention, values):
        """Return the attention-weighted sum of ``values`` over the sequence."""
        ret = tf.matmul(tf.expand_dims(attention, -2), values)
        ret = tf.squeeze(ret, -2)
        return ret

    def _masked_retrieval(self, query, keys, values, mask):
        """Attend over ``keys`` with ``query`` and read out ``values``."""
        attention = self._masked_attention(query, keys, mask)
        ret = self._apply_attention(attention, values)
        return ret

    # command network
    def _get_embeddings(self, inputs, voc_size, embedding_size):
        """Look up a fresh embedding table and regularize the result.

        A new ``[voc_size, embedding_size]`` variable is created on every
        call. The regularization term (masked with the command mask) is
        appended to ``self.reg_list``.

        Returns:
            The embedded ``inputs`` after ``_masked_entropy_regularization``.
        """
        embeddings = tf.Variable(
            tf.random_normal([voc_size, embedding_size], stddev=0.3))
        embed = tf.nn.embedding_lookup(embeddings, inputs)

        # regularization
        noisy_embed, reg = self._masked_entropy_regularization(
            embed, self.command_mask, self.x_command_length)
        self.reg_list.append(reg)
        return noisy_embed

    def _command_network(self, rep, max_size, f_size, p_size, out_size):
        """Summarize a command into a vector of size ``out_size``.

        The command is embedded twice: the first embedding (``f_size``)
        serves as attention keys for a learned query, the second
        (``p_size``) as the values that are read out. A dense layer maps the
        result to ``out_size``.

        Args:
            rep: Integer token ids of the command.
            max_size: Vocabulary size of both embedding tables.
            f_size: Width of the key embedding and of the query.
            p_size: Width of the value embedding.
            out_size: Width of the returned representation.
        """
        functional = self._get_embeddings(rep, max_size, f_size)
        primitive = self._get_embeddings(rep, max_size, p_size)
        query = tf.Variable(tf.random_normal([f_size], stddev=0.3))

        rep = self._masked_retrieval(
            query, functional, primitive, self.command_mask)
        rep = tf.layers.dense(rep, out_size)
        return rep

    # grounding network
    def _expansion_network(self, node):
        """Pass ``node`` through dense(16, relu) followed by dense(1)."""
        layer = tf.layers.dense(node, 16, activation=tf.nn.relu)
        layer = tf.layers.dense(layer, 1)
        return layer

    def _element_expansion(self, layer):
        """Transform every feature of the last axis on its own.

        The last axis is split into single-feature slices, each slice goes
        through its own ``_expansion_network`` call, and the outputs are
        concatenated back, so the shape of ``layer`` is preserved.
        """
        nodes = tf.split(layer, layer.shape[-1], -1)
        expansions = [self._expansion_network(n) for n in nodes]
        layer = tf.concat(expansions, -1)
        return layer

    def _grounding_network(self):
        """Attend over the world entries and return the attended position.

        For each of ``args.n_queries`` queries a command representation is
        scored against the expanded world features; the scores of all
        queries are summed. The resulting attention is stored in
        ``self.target_attention`` and applied to ``x_position``.
        """
        score_list = []
        for _ in range(self.args.n_queries):
            rep = self._command_network(
                self.x_command, self.configs['command_voc'], self.f_size,
                self.p_size, self.configs['attribute_vec_size'])

            layer = self._element_expansion(self.x_world)
            layer, reg = self._masked_entropy_regularization(
                layer, self.world_mask, self.x_world_length)
            self.reg_list.append(0.01 * reg)

            score = self._get_scores(rep, layer)
            score_list.append(score)

        scores = sum(score_list)
        # Scaling the scores sharpens the attention distribution.
        scores *= 10
        attention = self._masked_softmax(scores, self.world_mask)
        self.target_attention = attention
        target_position = self._apply_attention(attention, self.x_position)
        return target_position

    # prediction network
    def _one_network(self, pos, start, manner, out_size=1):
        """Map position, start flag and command code to ``out_size`` values.

        ``pos`` is expanded and regularized (the term is appended to
        ``self.reg_list``), concatenated with ``start`` and ``manner`` and
        passed through two relu layers of width 32 and a linear output layer.
        """
        layer = self._element_expansion(pos)
        layer, reg = self._entropy_regularization(layer)
        self.reg_list.append(reg)

        layer = tf.concat([layer, start, manner], -1)
        layer = tf.layers.dense(layer, 32, activation=tf.nn.relu)
        layer = tf.layers.dense(layer, 32, activation=tf.nn.relu)
        layer = tf.layers.dense(layer, out_size)
        return layer

    def _general_network(self, c_size=1):
        """Build one prediction head with four outputs.

        A command representation of width ``c_size`` is computed once and
        shared by four ``_one_network`` calls, each producing one output;
        the four outputs are concatenated along the last axis.
        """
        keyword = self._command_network(self.x_command,
                                        self.configs['command_voc'],
                                        self.f_size, self.p_size, c_size)
        start = tf.expand_dims(self.x_start, -1)
        pos = self.target_position

        nodes = [self._one_network(pos, start, keyword) for _ in range(4)]
        layer = tf.concat(nodes, -1)
        return layer

    def _create_logits(self):
        """Build the logits, ordered as direction, action, manner.

        The grounding network is built first and stored in
        ``self.target_position``; the three heads are then created in the
        order manner, action, direction and stacked along a new
        second-to-last axis.

        Returns:
            Float tensor whose last two axes are 3 (heads) x 4 (outputs).
        """
        # grounding network
        self.target_position = self._grounding_network()

        # type networks
        y_manner = self._general_network(c_size=8)
        y_action = self._general_network(c_size=8)
        y_direction = self._general_network(c_size=8)

        y_list = [y_direction, y_action, y_manner]
        y_list = [tf.expand_dims(a, -2) for a in y_list]
        logits = tf.concat(y_list, -2)
        return logits


class MixedOutputModel(Model):
    """Variant of ``Model`` whose heads emit all four outputs at once."""

    def _general_network(self, c_size=1):
        """Build one prediction head from a single ``_one_network`` call.

        Args:
            c_size: Width of the command representation fed to the head.

        Returns:
            The output of ``_one_network`` called with ``out_size`` 4.
        """
        command_code = self._command_network(
            self.x_command, self.configs['command_voc'],
            self.f_size, self.p_size, c_size)
        # Give the start flag a trailing axis so it can be concatenated.
        start_column = tf.expand_dims(self.x_start, -1)
        return self._one_network(
            self.target_position, start_column, command_code, 4)


class DirectInputModel(Model):
    """Variant of ``Model`` that skips the per-feature expansion networks."""

    def _element_expansion(self, layer):
        """Return ``layer`` unchanged instead of expanding its features."""
        return layer


class NormalModel(MixedOutputModel):
    """Variant of ``MixedOutputModel`` without per-feature expansion."""

    def _element_expansion(self, layer):
        """Return ``layer`` unchanged instead of expanding its features."""
        return layer
