# lint as: python3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pickle
from absl import app
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
import tensorflow.keras.layers as layers
import tensorflow.keras as keras
from tensorflow.keras.layers import Input
# from keras.layers import Lambda
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam, SGD
import matplotlib
# matplotlib.use('GTK3Agg')
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import seed
import tensorflow as tf
from PIL import Image
from matplotlib import cm
#seed(0)
#tf.random.set_seed(0)


class InceptionV3Energy(keras.Model):
    """Keras model trained with a custom loss on paired in/out batches.

    ``fit()`` is expected to yield ``((x_in, x_out), y_in)`` batches (see
    ``train_step``). The loss given to ``compile`` receives the labels together
    with the model outputs for both inputs.
    """

    def compile(self, optimizer, my_loss):
        """Configure the model for training.

        Args:
          optimizer: optimizer passed on to ``keras.Model.compile``.
          my_loss: callable ``my_loss(y_in, logits_in, logits_out)`` returning
            the training loss.
        """
        super().compile(optimizer)
        # ``super().compile`` already stored the resolved optimizer in
        # ``self.optimizer``; re-assigning the raw argument here would break
        # string identifiers such as "adam" (and any wrapping Keras applies).
        self.my_loss = my_loss
        # Replaces the container built by ``super().compile`` with the single
        # accuracy metric that ``train_step`` updates and reports as "acc".
        self.compiled_metrics = keras.metrics.CategoricalAccuracy()

    def train_step(self, data):
        """Run one optimisation step.

        Args:
          data: ``((x_in, x_out), y_in)`` as produced by the dataset given to
            ``fit()``; ``y_in`` labels only the ``x_in`` inputs.

        Returns:
          dict with the batch loss (``"loss"``) and the current categorical
          accuracy on ``x_in`` (``"acc"``).
        """
        x, y_in = data
        x_in, x_out = x[0], x[1]

        with tf.GradientTape() as tape:
            logits_in = self(x_in, training=True)
            logits_out = self(x_out, training=True)
            loss = self.my_loss(y_in, logits_in, logits_out)
            # Layer penalties (e.g. ``kernel_regularizer``) only reach the
            # objective through ``self.losses``; a custom train_step must add
            # them explicitly or they are silently ignored.
            if self.losses:
                loss = loss + tf.add_n(self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y_in, logits_in)

        return {"loss": loss, "acc": self.compiled_metrics.result()}

def prepare_InceptionV3(modelpath, input_size=(224,224), logits=True, pretrain=False, return_model=False, num_classes=50):
    """Build a classifier head on a frozen ImageNet InceptionV3 backbone.

    Inputs of ``input_size`` are resized to 224x224 in-graph, fed through a
    frozen ImageNet InceptionV3 (global max pooling over the ``mixed10``
    output), then through Flatten -> Dense(256, relu, l2) -> Dropout(0.5) ->
    Dense(num_classes, l2).

    Args:
      modelpath: weights file loaded by layer name; only used when both
        ``return_model`` and ``pretrain`` are true.
      input_size: (height, width) of the input images.
      logits: if true the last layer is linear (logits), else softmax.
      pretrain: load weights from ``modelpath`` into the returned model.
      return_model: if true return a compiled ``tf.keras.Model``; otherwise
        return ``(input_tensor, output_tensor)`` for building a custom model.
      num_classes: number of output units (defaults to 50).

    Returns:
      The compiled model, or the ``(input_tensor, output_tensor)`` pair.
    """
    input_tensor = tf.keras.Input(shape=(input_size[0], input_size[1], 3))
    # The backbone always sees 224x224 images, whatever ``input_size`` is.
    resized_images = layers.Lambda(lambda image: tf.image.resize(image, (224, 224)))(input_tensor)
    base_model = tf.keras.applications.InceptionV3(weights='imagenet',
                                                   include_top=False,
                                                   input_tensor=resized_images,
                                                   pooling='max')
    # Freeze the ImageNet backbone; only the head created below is trainable.
    for layer in base_model.layers:
        layer.trainable = False
    output_from_model = base_model.layers[-2].output  # mixed10
    global_pool = base_model.layers[-1]  # pooling layer added by pooling='max'
    global_pool_out = global_pool(output_from_model)

    # NOTE: layers are created in the same order as before so their
    # auto-generated names (used by load_weights(by_name=True)) are stable.
    flatten_out = layers.Flatten()(global_pool_out)
    fc1 = layers.Dense(units=256, activation='relu',
                       kernel_regularizer=tf.keras.regularizers.l2())
    fc1_out = fc1(flatten_out)
    dropout = layers.Dropout(0.5)
    dropout_out = dropout(fc1_out)
    fc2 = layers.Dense(units=num_classes,
                       activation=None if logits else 'softmax',
                       kernel_regularizer=tf.keras.regularizers.l2())
    # Logits when ``logits`` is true, softmax probabilities otherwise.
    output_tensor = fc2(dropout_out)

    if not return_model:
        return input_tensor, output_tensor

    model = tf.keras.models.Model(inputs=input_tensor, outputs=output_tensor)
    # The loss must be told whether it receives logits: the plain
    # 'categorical_crossentropy' string assumes probabilities, which is wrong
    # for the default logits=True head.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=bool(logits)),
                  metrics=['accuracy'])
    if pretrain:
        model.load_weights(modelpath, by_name=True)

    return model




class TopicModelMahalanobis(keras.Model):
    """Model trained on feature maps with a loss over topic projections.

    ``fit()`` is expected to yield ``((x_in, x_out), y_in)`` batches. Both
    inputs are mapped through ``feature_model``; the model itself is applied
    to the ``x_in`` features. The features of both inputs are L2-normalised
    along axis 3 and multiplied with ``topic_vector_n`` before being handed
    to the loss.
    """

    def compile(self, feature_model, topic_vector_n, n_concept, optimizer, my_loss, run_eagerly=False):
        """Configure the model for training.

        Args:
          feature_model: callable mapping an input batch to feature maps;
            it is called outside the gradient tape in ``train_step``.
          topic_vector_n: matrix the normalised features are multiplied with
            via ``K.dot``; presumably (channels, n_concept) -- TODO confirm.
          n_concept: forwarded unchanged to ``my_loss``.
          optimizer: optimizer passed on to ``keras.Model.compile``.
          my_loss: callable ``my_loss(y_in, logits_in, topic_prob_n_in,
            topic_prob_n_out, topic_vector_n, n_concept)`` returning the loss.
          run_eagerly: forwarded to ``keras.Model.compile``. The per-step
            debug printout in ``train_step`` is only produced when eager.
        """
        super().compile(optimizer, run_eagerly=run_eagerly)
        self.feature_model = feature_model
        self.topic_vector_n = topic_vector_n
        self.n_concept = n_concept
        self.my_loss = my_loss
        # ``super().compile`` already stored the resolved optimizer; do not
        # overwrite it with the raw argument (breaks e.g. "adam").
        self.compiled_metrics = keras.metrics.CategoricalAccuracy()
        # Counter for the debug printout in ``train_step``; a value set
        # before (re-)compiling is kept.
        self.step_counter = getattr(self, "step_counter", 0)

    def train_step(self, data):
        """Run one optimisation step.

        Args:
          data: ``((x_in, x_out), y_in)`` batch from ``fit()``; ``y_in`` is
            reduced with argmax over axis 1, i.e. assumed one-hot.

        Returns:
          dict with the batch loss (``"loss"``) and the current categorical
          accuracy of the model on the ``x_in`` features (``"acc"``).
        """
        # Debug output uses Python/NumPy on concrete values, so it is only
        # emitted when running eagerly; under tf.function it is skipped.
        eager = tf.executing_eagerly()
        if eager:
            print("----Start of step: %d" % (self.step_counter,))
            self.step_counter += 1

        x, y_in = data
        x_in, x_out = x[0], x[1]

        # Features are computed outside the tape. Normalising with TF ops
        # (same formula as before) keeps them tensors rather than NumPy
        # arrays, which is what keras.backend.dot expects and which also
        # traces under tf.function.
        f_in = self.feature_model(x_in)
        f_in_n = f_in / (tf.norm(f_in, axis=3, keepdims=True) + 1e-9)
        f_out = self.feature_model(x_out)
        f_out_n = f_out / (tf.norm(f_out, axis=3, keepdims=True) + 1e-9)

        n_concept = self.n_concept
        topic_vector_n = self.topic_vector_n

        with tf.GradientTape() as tape:
            logits_in = self(f_in, training=True)
            topic_prob_n_in = K.dot(f_in_n, topic_vector_n)
            topic_prob_n_out = K.dot(f_out_n, topic_vector_n)
            loss = self.my_loss(y_in, logits_in, topic_prob_n_in, topic_prob_n_out, topic_vector_n, n_concept)
            # NOTE(review): layer regularisation losses (self.losses) are not
            # added here -- confirm whether my_loss is meant to cover them.

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        if eager:
            print(gradients)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y_in, logits_in)

        if eager:
            predicted = np.argmax(logits_in, axis=1)
            expected = np.argmax(y_in, axis=1)
            print(predicted)
            print(expected)
            # Fraction of samples whose predicted class matches the label
            # (previously raw logits were compared to one-hot labels).
            print("accuracy: {}".format(np.mean(predicted == expected)))

        return {"loss": loss, "acc": self.compiled_metrics.result()}
