import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

from tools import save_images
from tools import save_text
from tools import save_hidden


def get_algorithm(args, dg):
    """Create the model named by ``args.algorithm`` and wrap it in Algorithm.

    Args:
        args: configuration namespace; ``parameter_random_seed`` and
            ``algorithm`` are read here, and ``args`` itself is passed on to
            the model and ``Algorithm`` constructors.
        dg: data generator; must provide ``get_input_shape`` and
            ``get_output_nodes`` (plus whatever ``Algorithm`` reads).

    Returns:
        An ``Algorithm`` instance bound to the freshly built model.

    Raises:
        ValueError: if ``args.algorithm`` is not 'aux', 'fmonth' or 'color'.
    """
    # Graph-level seed; a non-positive value leaves TF's seeding untouched.
    if args.parameter_random_seed > 0:
        tf.random.set_random_seed(args.parameter_random_seed)

    in_shape = dg.get_input_shape()
    out_shape = dg.get_output_nodes()
    if args.algorithm == 'aux':
        model = MNISTModel(args, in_shape, out_shape)
    elif args.algorithm == 'fmonth':
        model = MonthModel(args, in_shape, out_shape)
    elif args.algorithm == 'color':
        model = ColorModel(args, in_shape, out_shape)
    else:
        # Raise explicitly so the check is not stripped under `python -O`.
        raise ValueError(
            'Unknown algorithm %r; expected one of aux, fmonth, color'
            % (args.algorithm,))
    return Algorithm(args, dg, model)


class Algorithm(object):
    """Train a model, then evaluate it with and without test-time inference.

    ``dg`` is the data generator and ``model`` one of the model classes in
    this module. The evaluation, test and "memory" samples are drawn once
    in the constructor and cached, so every later call sees the same data.
    """

    def __init__(self, args, dg, model):
        """Cache the evaluation/test sets and the batches used as memory.

        Args:
            args: configuration; reads ``test_size``, ``memory_size`` and
                ``batch_size``.
            dg: data generator providing ``get_eval_samples``,
                ``get_test_samples`` and ``get_training_samples``.
            model: object exposing ``run``, ``extraction``, ``test_normal``
                and ``test``.
        """
        self.args = args
        self.dg = dg
        self.model = model
        self.eval_data = dg.get_eval_samples(args.test_size)
        self.test_data = dg.get_test_samples(args.test_size)

        # Ceiling division: number of batch_size-sized batches needed to
        # hold memory_size samples. Each batch is drawn with batch_size so
        # the collected total is always >= memory_size.
        memory_batch, remainder = divmod(args.memory_size, args.batch_size)
        if remainder > 0:
            memory_batch += 1

        self.memory_data = [dg.get_training_samples(args.batch_size)
                            for _ in range(memory_batch)]

    def get_memory(self):
        """Encode the cached memory batches and truncate to memory_size.

        Returns:
            ``([memory1, memory2], [Y1_list, Y2_list])``: ``memory1`` and
            ``memory2`` are lists of per-sample codes from the two outputs
            of ``model.extraction``; ``Y1_list``/``Y2_list`` are the matching
            labels. Each list holds at most ``args.memory_size`` entries.
        """
        memory1, memory2 = [], []
        Y1_list, Y2_list = [], []
        for X, (Y1, Y2) in self.memory_data:
            m1, m2 = self.model.extraction(X)
            memory1.extend(m1)
            memory2.extend(m2)
            Y1_list.extend(Y1)
            Y2_list.extend(Y2)

        size = self.args.memory_size
        memory = [memory1[:size], memory2[:size]]
        return memory, [Y1_list[:size], Y2_list[:size]]

    def run(self):
        """Train the model, build the memory, and print four evaluations.

        Evaluates the eval (original distribution) and test (changed
        distribution) sets both with the plain encoder (``test_normal``) and
        with memory-based test-time inference (``test``).
        """
        self.model.run(self.dg)

        memory, _ = self.get_memory()

        print('Normal test on original distribution:',
              *self.model.test_normal(self.eval_data, 'eval_normal'))
        print('Normal test on changed distribution:',
              *self.model.test_normal(self.test_data, 'test_normal'))
        print('Test on original distribution:',
              *self.model.test(self.eval_data, memory, 'eval'))
        print('Test on changed distribution:',
              *self.model.test(self.test_data, memory, 'test'))


class MNISTModel(object):
    """Two-branch encoder/classifier with optional test-time inference.

    The same input ``X`` goes through two independent encoders producing
    hidden codes ``h1`` and ``h2``. Each code feeds its own dense head,
    which predicts ``Y1`` and ``Y2`` respectively. Depending on ``args``,
    the graph also adds:

    * an entropy regularizer (L2 penalty plus Gaussian noise on the codes);
    * decoders with a reconstruction loss;
    * a "manifold" loss: the mean distance from each code to its nearest
      stored memory code.

    When ``use_encoder`` is fed as 0, the codes come from the trainable
    variables ``h_test`` instead of the encoders. ``test`` optimizes those
    variables against ``inference_losses`` before predicting.
    """

    def __init__(self, args, input_shape, output_nodes):
        """Build the TF graph and initialize its variables in a new session.

        Args:
            args: configuration namespace (hidden size, loss switches and
                weights, learning rates, step counts, ...).
            input_shape: shape of a single input sample; element 2 is read
                as the channel depth.
            output_nodes: pair of class counts for the two prediction heads.
        """
        self.args = args
        self.input_shape = input_shape
        self.output_nodes = output_nodes

        # model creation
        self._create_model()

        # initialization
        init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init)

    def _get_encoder(self, x):
        """Map an input batch to a linear ``n_hidden_nodes``-wide code.

        The stack is: two 3x3 ReLU convolutions with 2x2 max pooling
        between them, then flatten, then a linear dense layer. Assumes
        channels-last image input (Keras default).
        NOTE(review): confirm for non-MNIST data.
        """
        x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu')(x)
        x = tf.keras.layers.MaxPooling2D((2, 2))(x)
        x = tf.keras.layers.Conv2D(64, (3, 3), activation='relu')(x)
        x = tf.keras.layers.Flatten()(x)
        x = tf.keras.layers.Dense(self.args.n_hidden_nodes)(x)
        return x

    def _get_encoders(self, X):
        """Build two encoders over ``X``; each call creates separate weights."""
        h1 = self._get_encoder(X)
        h2 = self._get_encoder(X)
        return [h1, h2]

    def _get_decoder(self, x, output_depth=1):
        """Decode a code into a 28x28 map with ``output_depth`` channels.

        The stack is: dense, reshape to 7x7x32, then two stride-2 transposed
        convolutions (7 -> 14 -> 28). A final 3x3 transposed convolution has
        no activation, so the output is unbounded.
        """
        x = tf.keras.layers.Dense(units=7 * 7 * 32, activation=tf.nn.relu)(x)
        x = tf.keras.layers.Reshape(target_shape=(7, 7, 32))(x)
        x = tf.keras.layers.Conv2DTranspose(
            filters=64,
            kernel_size=3,
            strides=(2, 2),
            padding="SAME",
            activation='relu')(x)
        x = tf.keras.layers.Conv2DTranspose(
            filters=32,
            kernel_size=3,
            strides=(2, 2),
            padding="SAME",
            activation='relu')(x)
        # No activation
        x = tf.keras.layers.Conv2DTranspose(filters=output_depth,
                                            kernel_size=3, strides=(1, 1),
                                            padding="SAME")(x)
        return x

    def _get_decoders(self, h1, h2):
        """Decode each code separately and average the reconstructions.

        Returns:
            ``(X_hat, X1_hat, X2_hat)`` where ``X_hat`` is the mean of the
            two per-branch reconstructions.
        """
        X1_hat = self._get_decoder(h1)
        X2_hat = self._get_decoder(h2)
        X_hat = (X1_hat + X2_hat) / 2
        return X_hat, X1_hat, X2_hat

    def _entropy_regularization(self, h):
        """Penalize the code's L2 norm and perturb it with scaled noise.

        Returns:
            ``(h_noisy, reg)``. ``reg = sum(h ** 2) / 2 / (k * batch_size)``.
            ``h_noisy = h + sigma * add_noise * N(0, 1)``; feeding
            ``add_noise`` as 0 disables the perturbation.
        """
        reg = tf.nn.l2_loss(h) / (self.k * self.args.batch_size)
        # tf.random.normal needs a static shape. This presumably relies on
        # h inheriting batch_size rows from the h_test variables it is
        # mixed with. NOTE(review): confirm.
        h += self.args.sigma * self.add_noise * tf.random.normal(h.shape)
        return h, reg

    def _manifold_regularization(self, h, m):
        """Mean distance from each code in ``h`` to its nearest row of ``m``.

        Distance is the mean squared difference over the last (feature)
        axis. ``h`` is expected as (batch, k) and ``m`` as (memory, k), so
        the broadcast difference is (batch, memory, k).
        """
        # compute l2 distance
        he = tf.expand_dims(h, 1)
        me = tf.expand_dims(m, 0)
        diff = he - me
        square = tf.square(diff)
        l2 = tf.reduce_mean(square, -1)

        # find nearest neighbor
        nearest = tf.reduce_min(l2, -1)

        # compute mean as loss
        loss = tf.reduce_mean(nearest)
        return loss

    def _create_model(self):
        """Construct placeholders, networks, losses, metrics and optimizers."""
        self.k = self.args.n_hidden_nodes
        self.input_depth = self.input_shape[2]
        self.input_size = np.prod(self.input_shape)
        # 1 -> codes come from the encoders; 0 -> from the h_test variables.
        self.use_encoder = tf.placeholder(shape=(), dtype=tf.float32)
        # Multiplies the entropy-regularization noise; 0 disables it.
        self.add_noise = tf.placeholder(shape=(), dtype=tf.float32)

        self.X = tf.placeholder(shape=(None, *self.input_shape),
                                dtype=tf.float32)
        self.Y1 = tf.placeholder(shape=(None,), dtype=tf.int64)
        self.Y2 = tf.placeholder(shape=(None,), dtype=tf.int64)

        # Memory codes for the manifold regularizer, one per branch.
        self.m1 = tf.placeholder(shape=(None, self.k), dtype=tf.float32)
        self.m2 = tf.placeholder(shape=(None, self.k), dtype=tf.float32)

        self.h_train = self._get_encoders(self.X)

        # Free code variables optimized at inference time. They have a fixed
        # batch_size rows, so fed batches presumably need batch_size rows
        # for the mix below to broadcast.
        self.h_test = [
            tf.Variable(
                np.zeros((self.args.batch_size, self.args.n_hidden_nodes)),
                dtype=tf.float32) for _ in self.h_train]

        # 0/1 mix selecting encoder codes or free codes.
        h_before = [self.use_encoder * tr + (1 - self.use_encoder) * ts for
                    tr, ts in zip(self.h_train, self.h_test)]
        self.h1 = h_before[0]
        self.h2 = h_before[1]

        self.train_losses = []
        self.inference_losses = []

        # entropy regularization
        if self.args.use_entropy_regularization_train:
            h1, entropy_reg1 = self._entropy_regularization(h_before[0])
            h2, entropy_reg2 = self._entropy_regularization(h_before[1])
            h_after = [h1, h2]
            self.entropy_reg = entropy_reg1 + entropy_reg2
            self.train_losses.append(self.args.alpha * self.entropy_reg)
            if self.args.use_entropy_regularization_inference:
                self.inference_losses.append(
                    self.args.alpha * self.entropy_reg)
        else:
            h_after = h_before

        # manifold regularization (inference-time only)
        if self.args.use_manifold_regularization:
            manifold_reg1 = self._manifold_regularization(h_after[0], self.m1)
            manifold_reg2 = self._manifold_regularization(h_after[1], self.m2)
            self.manifold_reg = manifold_reg1 + manifold_reg2
            self.inference_losses.append(self.args.gamma * self.manifold_reg)

        # reconstruction network
        if self.args.use_reconstruction_train:
            self.X_hat, self.X1_hat, self.X2_hat = self._get_decoders(
                h_after[0], h_after[1])
            self.reconstruction_loss = tf.nn.l2_loss(
                self.X - self.X_hat) / self.args.batch_size
            self.train_losses.append(self.args.beta * self.reconstruction_loss)
            if self.args.use_reconstruction_inference:
                self.inference_losses.append(
                    self.args.beta * self.reconstruction_loss)

        # prediction network
        s1 = tf.layers.dense(h_after[0], self.output_nodes[0])
        s2 = tf.layers.dense(h_after[1], self.output_nodes[1])

        loss1 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.Y1, logits=s1)
        loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=self.Y2, logits=s2)
        self.prediction_loss = tf.reduce_mean(loss1 + loss2)
        self.train_losses.append(self.prediction_loss)

        # accuracy: per head, and joint (both heads correct)
        self.pred1 = tf.argmax(s1, -1)
        self.pred2 = tf.argmax(s2, -1)
        equal1 = tf.cast(tf.equal(self.Y1, self.pred1), tf.float32)
        equal2 = tf.cast(tf.equal(self.Y2, self.pred2), tf.float32)
        equal = equal1 * equal2
        self.acc1 = tf.reduce_mean(equal1)
        self.acc2 = tf.reduce_mean(equal2)
        self.acc = tf.reduce_mean(equal)

        # transfer error rate: among samples with a joint error, the
        # fraction whose two predicted labels have an even sum.
        error = 1 - equal
        even = 1 - (self.pred1 + self.pred2) % 2
        even_error = error * tf.cast(even, tf.float32)
        self.even_error_ratio = tf.reduce_mean(even_error)
        self.error_ratio = tf.reduce_mean(error)
        # A batch with no errors would give 0/0 = NaN; report 0 instead so
        # the running averages in `run` are not poisoned.
        self.ter = tf.math.divide_no_nan(self.even_error_ratio,
                                         self.error_ratio)

        # training optimizer
        self.lr = tf.placeholder(shape=(), dtype=tf.float32)
        self.loss = sum(self.train_losses)
        optimizer = tf.train.AdamOptimizer(self.lr)
        if self.args.max_gradient_norm <= 0:
            self.optimizer = optimizer.minimize(self.loss)
        else:
            params = tf.trainable_variables()
            gradients = tf.gradients(self.loss, params)
            clipped_gradients, _ = tf.clip_by_global_norm(
                gradients, self.args.max_gradient_norm)
            self.optimizer = optimizer.apply_gradients(
                zip(clipped_gradients, params))

        # inference optimizer: only updates the free h_test codes
        if len(self.inference_losses) > 0:
            self.inference_loss = sum(self.inference_losses)
            self.test_optimizer = tf.train.AdamOptimizer(
                self.args.test_learning_rate).minimize(self.inference_loss,
                                                       var_list=self.h_test)

    def run(self, dg):
        """Train for ``args.steps`` batches, logging windowed averages.

        Every ``args.log_steps`` steps, prints the step number followed by
        the window mean of: total loss, each training-loss term, joint
        accuracy, transfer error rate, head-1 accuracy and head-2 accuracy.
        """
        fetch = [self.optimizer, self.loss, *self.train_losses, self.acc,
                 self.ter, self.acc1, self.acc2]
        # Running sums of every fetched value except the optimizer op.
        acc_loss = np.zeros(len(fetch) - 1)
        count = 0
        for epoch in range(self.args.steps):
            X, [Y1, Y2] = dg.get_training_samples(self.args.batch_size)
            feed = {self.use_encoder: 1, self.add_noise: 1, self.X: X,
                    self.Y1: Y1, self.Y2: Y2, self.lr: self.args.lr}
            fetched = self.sess.run(fetch, feed_dict=feed)
            acc_loss += np.asarray(fetched[1:])
            count += 1
            if (epoch + 1) % self.args.log_steps == 0:
                result = acc_loss / count
                print(epoch + 1, *result)
                # Fresh zeros (0 * acc_loss would keep any NaN forever).
                acc_loss = np.zeros_like(acc_loss)
                count = 0

    def extraction(self, X):
        """Return the encoder codes ``[h1, h2]`` for the batch ``X``."""
        h = self.sess.run(self.h_train,
                          feed_dict={self.use_encoder: 1, self.X: X})
        return h

    def test_normal(self, data, name):
        """Evaluate with the encoders only (no test-time optimization).

        Args:
            data: ``(X, [Y1, Y2], extra)``; the third item is ignored here.
            name: prefix for saved predictions and codes; empty skips saving.

        Returns:
            ``[loss, prediction_loss, acc, ter, acc1, acc2]``.
        """
        X, [Y1, Y2], _ = data
        fetch = [self.loss, self.prediction_loss, self.acc, self.ter,
                 self.acc1, self.acc2, self.pred1, self.pred2, self.h1,
                 self.h2]
        feed = {self.use_encoder: 1, self.add_noise: 0, self.X: X, self.Y1: Y1,
                self.Y2: Y2}
        fetched = self.sess.run(fetch, feed_dict=feed)
        if len(name) > 0:
            self.cf_matrix([Y1, Y2], fetched[-4:-2], name)
            self.save_hidden([Y1, Y2], fetched[-2:], name + "_hidden")
        return fetched[:-4]

    def test(self, data, memory, name='test'):
        """Evaluate after optimizing the free codes ``h_test`` on ``data``.

        Falls back to ``test_normal`` when ``args.inference_steps`` is 0.
        Otherwise it does the following:

        1. Optionally initializes ``h_test`` from the encoder codes.
        2. Runs ``inference_steps`` optimizer steps on ``inference_loss``,
           printing each step.
        3. Predicts from the optimized codes.
        4. Saves predictions, codes and, when reconstruction is enabled,
           images.

        Args:
            data: ``(X, [Y1, Y2], [X1, X2])``; ``X1``/``X2`` are only saved
                as images.
            memory: ``[m1, m2]`` memory codes fed to the manifold regularizer.
            name: prefix for saved artifacts.

        Returns:
            ``[loss, prediction_loss, acc, ter, acc1, acc2]``.

        Raises:
            ValueError: if inference is requested but no inference loss was
                configured (so there is nothing to optimize).
        """
        if self.args.inference_steps == 0:
            return self.test_normal(data, name)
        if not self.inference_losses:
            raise ValueError(
                'inference_steps > 0 requires at least one inference loss '
                '(manifold, entropy or reconstruction regularization)')

        X, [Y1, Y2], [X1, X2] = data
        if not self.args.random_initialize:
            # run forward from input to hidden
            h = self.extraction(X)
            # keep hidden values as initial values.
            for variable, value in zip(self.h_test, h):
                variable.load(value, session=self.sess)

        # inference optimization
        inference_noise = (
            1 if self.args.use_entropy_regularization_inference else 0)
        fetch = [self.test_optimizer, self.inference_loss,
                 *self.inference_losses, self.acc, self.ter, self.acc1,
                 self.acc2]
        for epoch in range(self.args.inference_steps):
            feed = {self.use_encoder: 0, self.add_noise: inference_noise,
                    self.X: X, self.m1: memory[0], self.m2: memory[1],
                    self.Y1: Y1, self.Y2: Y2}
            fetched = self.sess.run(fetch, feed_dict=feed)
            print(epoch, *fetched[1:])

        # prediction
        pred_fetch = [self.loss, self.prediction_loss, self.acc, self.ter,
                      self.acc1, self.acc2, self.pred1, self.pred2, self.h1,
                      self.h2]
        pred_feed = {self.use_encoder: 0, self.add_noise: 0, self.X: X,
                     self.Y1: Y1, self.Y2: Y2}
        fetched = self.sess.run(pred_fetch, feed_dict=pred_feed)

        # visualization
        if self.args.use_reconstruction_train:
            visual_fetch = [self.X_hat, self.X1_hat, self.X2_hat]
            X_hat, X1_hat, X2_hat = self.sess.run(visual_fetch,
                                                  feed_dict=pred_feed)
            self.save(X, name + '_x_images')
            self.save(X_hat, name + '_x_hat_images')
            self.save(X1_hat, name + '_x1_hat_images')
            self.save(X2_hat, name + '_x2_hat_images')
            self.save(X1, name + '_x1_images')
            self.save(X2, name + '_x2_images')

        self.cf_matrix([Y1, Y2], fetched[-4:-2], name)
        self.save_hidden([Y1, Y2], fetched[-2:], name + "_hidden")
        return fetched[:-4]

    def save(self, X, name):
        """Save an image batch via the module-level ``save_images`` helper."""
        save_images(self.args.experiment_id, X, name)

    def cf_matrix(self, Y, Y_hat, name):
        """Save true and predicted labels of both heads via ``save_text``."""
        Y1, Y2 = Y
        Y1_hat, Y2_hat = Y_hat
        data = [Y1, Y2, Y1_hat, Y2_hat]
        save_text(self.args.experiment_id, data, name)

    def save_hidden(self, Y, H, name):
        """Save labels and hidden codes via the module-level ``save_hidden``."""
        Y1, Y2 = Y
        H1, H2 = H
        data = [Y1, Y2, H1, H2]
        # Resolves to the module-level function, not this method.
        save_hidden(self.args.experiment_id, data, name)


class MonthModel(MNISTModel):
    """Fully-connected variant of MNISTModel.

    Inputs are flattened. With ``k = args.n_hidden_nodes``, the encoder is a
    3k -> 2k -> k MLP and the decoder a 2k -> 3k -> input_size MLP whose
    output is reshaped back to the input shape.
    """

    def _get_encoder(self, x):
        """Flatten ``x`` and map it through ReLU layers to a linear k-wide code."""
        k = self.args.n_hidden_nodes
        hidden = tf.reshape(x, [-1, self.input_size])
        for width in (3 * k, 2 * k):
            hidden = tf.layers.dense(hidden, width, activation=tf.nn.relu)
        return tf.layers.dense(hidden, k)

    def _get_decoder(self, x, output_depth=1):
        """Map a code back to an input-shaped tensor (``output_depth`` unused)."""
        k = self.args.n_hidden_nodes
        hidden = x
        for width in (2 * k, 3 * k):
            hidden = tf.layers.dense(hidden, width, activation=tf.nn.relu)
        flat = tf.layers.dense(hidden, self.input_size)
        return tf.reshape(flat, [-1, *self.input_shape])


class ColorModel(MNISTModel):
    """MNISTModel variant that reconstructs a 3-channel image from both codes."""

    def _get_decoders(self, h1, h2):
        """Decode the concatenated codes with a single 3-channel decoder.

        The same reconstruction tensor is returned three times, so the
        "averaged" and per-branch outputs coincide.
        """
        joint_code = tf.concat([h1, h2], axis=-1)
        reconstruction = self._get_decoder(joint_code, output_depth=3)
        return (reconstruction,) * 3
