# coding=utf-8
# Copyright 2023.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

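"""Uncertainty Aware Information Bottleneck (UAIB) on a synthetic regression task.

Trains a small MLP with a ``DenseUAIB`` bottleneck on ``y = x**3 + noise`` and
plots the prediction together with a +/- 2 * uncertainty band.
"""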


import numpy as np
from absl import app
from absl import flags

import tensorflow.compat.v2 as tf
import edward2 as ed
import matplotlib.pyplot as plt

from layers import DenseUAIB

FLAGS = flags.FLAGS

# --- Optimisation / data settings -------------------------------------------
flags.DEFINE_integer(name='epochs', default=1500, help='Number of epochs.')
flags.DEFINE_integer(name='num_train_examples', default=20,
                     help='Number of training datapoints.')
flags.DEFINE_float(name='learning_rate', default=0.01, help='Learning rate.')

# --- Bottleneck settings -----------------------------------------------------
flags.DEFINE_float(
    name="beta",
    default=1.0,
    help="Uncertainty Aware Information Bottleneck lagrange multiplier for regularization.",
)
flags.DEFINE_float(
    name="uaib_tau",
    default=5.0,
    help=" Temperature of the variational marginal distribution.",
)
flags.DEFINE_integer(name="uaib_dim", default=8, help="Bottleneck dimension")
flags.DEFINE_integer(
    name='example',
    default=1,
    help='Example to be tested: 1. datapoints in the middle. 2: datapoints at the edges',
)
flags.DEFINE_integer(name="codebook_size", default=1, help="Codebook size. ")

# --- Diagnostics -------------------------------------------------------------
flags.DEFINE_bool(name='verbose', default=False, help='Print numerical details.')

# Font sizes shared by every figure this script draws.
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

plt.rc('font', size=BIGGER_SIZE)                                 # default text size
plt.rc('axes', titlesize=BIGGER_SIZE, labelsize=MEDIUM_SIZE)     # axes title / x-y labels
plt.rc('xtick', labelsize=MEDIUM_SIZE)                           # x tick labels
plt.rc('ytick', labelsize=MEDIUM_SIZE)                           # y tick labels
plt.rc('legend', fontsize=BIGGER_SIZE)                           # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)                          # figure title



def create_dataset(num_train_examples=None, example=None):
    """Create the synthetic cubic regression dataset.

    Training targets are ``y = x**3 + noise``, where ``noise`` is drawn with
    ``np.random.normal(0, 9)`` (mean 0, standard deviation 9). Training inputs
    are drawn uniformly (half-open intervals) from two disjoint ranges chosen
    by ``example``:

    * ``example == 1``: [-4, 0) and [0, 4), so the data covers the middle.
    * any other value: [-5, -2) and [2, 5), so (-2, 2) has no training data.

    Args:
        num_train_examples: Number of training points. When ``None``,
            ``FLAGS.num_train_examples`` is used.
        example: Layout selector described above. When ``None``,
            ``FLAGS.example`` is used.

    Returns:
        ``(y_train, x_train, x_test, y_test)`` as float32 arrays. The training
        arrays have shape ``(num_train_examples, 1)``. The test arrays are a
        noiseless grid of 50 points on [-5, 5] with shape ``(50, 1)``. Note the
        order: training targets come before training inputs.
    """
    if num_train_examples is None:
        num_train_examples = FLAGS.num_train_examples
    if example is None:
        example = FLAGS.example

    # The two halves must add up to num_train_examples. Using ``n // 2`` for
    # both dropped one input when n was odd, and ``x ** 3 + noise`` then
    # failed to broadcast.
    n_left = num_train_examples // 2
    n_right = num_train_examples - n_left

    if example == 1:
        left_bounds, right_bounds = (-4.0, 0.0), (0.0, 4.0)
    else:
        left_bounds, right_bounds = (-5.0, -2.0), (2.0, 5.0)

    # RNG call order (left, right, noise) matches the original implementation.
    x_left = np.random.uniform(*left_bounds, size=n_left).reshape((-1, 1))
    x_right = np.random.uniform(*right_bounds, size=n_right).reshape((-1, 1))
    x_train = np.concatenate([x_left, x_right])

    noise = np.random.normal(0, 9, size=num_train_examples).reshape((-1, 1))
    y_train = x_train ** 3 + noise

    # np.linspace defaults to 50 samples.
    x_test = np.linspace(-5, 5).reshape((-1, 1))
    y_test = x_test ** 3

    return (np.float32(y_train), np.float32(x_train),
            np.float32(x_test), np.float32(y_test))
    

def multilayer_perceptron():
    """Build the UAIB regression model.

    Architecture: scalar input -> two 100-unit ELU ``Dense`` layers ->
    ``DenseUAIB`` bottleneck (named ``"dense_uaib"``, configured from
    ``FLAGS.uaib_dim``, ``FLAGS.codebook_size`` and ``FLAGS.uaib_tau``) ->
    linear ``Dense(1)`` head applied to the latent features only.

    The bottleneck output is split on the last axis into ``FLAGS.uaib_dim``
    latent features and one uncertainty column. This assumes ``DenseUAIB``
    emits ``uaib_dim + 1`` features; TODO confirm against ``layers.DenseUAIB``.

    Returns:
        A ``tf.keras.Model`` whose output concatenates, along the last axis,
        the prediction (column 0) and the uncertainty (column 1).
    """
    # ``shape`` must be a tuple: ``(1)`` is just the int 1, ``(1,)`` is a
    # 1-element shape (one scalar feature per sample).
    inputs = tf.keras.layers.Input(shape=(1,))

    hidden_1 = tf.keras.layers.Dense(units=100, activation='elu')(inputs)
    hidden_2 = tf.keras.layers.Dense(units=100, activation='elu')(hidden_1)

    uaib_output = DenseUAIB(
        uaib_dim=FLAGS.uaib_dim,
        codebook_size=FLAGS.codebook_size,
        uaib_tau=FLAGS.uaib_tau,
        name="dense_uaib",  # main() looks the layer up by this name.
        activation=None,
        momentum=0.0,
    )(hidden_2)

    latent_features, uncertainty = tf.split(
        uaib_output, [FLAGS.uaib_dim, 1], axis=-1)

    # The regression head sees only the latent features, not the uncertainty.
    output = tf.keras.layers.Dense(units=1, activation=None)(latent_features)

    return tf.keras.Model(
        inputs=inputs, outputs=tf.concat([output, uncertainty], axis=-1))
    
def _plot_predictions(x_train, y_train, x_test, y_test, y_pred, uncertainty,
                      show_legend):
    """Plot training data, ground truth, prediction and a +/-2*uncertainty band.

    Args:
        x_train, y_train: Training inputs/targets, drawn as a scatter plot.
        x_test, y_test: Test grid and its noiseless targets (ground truth).
        y_pred: Model predictions on ``x_test``.
        uncertainty: Model uncertainty on ``x_test``. The shaded band spans
            ``y_pred - 2 * uncertainty`` to ``y_pred + 2 * uncertainty``.
        show_legend: If True, draw the legend at the upper centre.
            Otherwise the legend is hidden.
    """
    y_pred_low = np.squeeze(y_pred - 2 * uncertainty)
    y_pred_high = np.squeeze(y_pred + 2 * uncertainty)

    plt.fill_between(np.squeeze(x_test), y_pred_low, y_pred_high,
                     color='coral', alpha=.5, label='Uncertainty')
    plt.plot(x_test, y_pred, c='royalblue', label='Prediction', linewidth=2)
    plt.scatter(x_train, y_train, c='navy', label='Train Datapoint')
    plt.plot(x_test, y_test, c='grey', label='Ground Truth', linewidth=2)

    if show_legend:
        # Create the legend exactly once. A second ``plt.legend()`` call would
        # replace it with a new legend at the default location, losing
        # ``loc='upper center'``.
        plt.legend(loc='upper center')
    else:
        plt.legend().set_visible(False)

    plt.tight_layout()
    plt.show()


def main(argv):
    """Train the UAIB regression model on synthetic data and plot the result.

    Each epoch runs one full-batch ``train_step``, which updates every
    non-centroid variable, followed by one ``cluster_step``, which updates the
    centroid variables and centroid probabilities of the ``dense_uaib`` layer.

    Args:
        argv: Command-line arguments forwarded by ``absl.app.run``. Unused.
    """
    del argv  # Unused.

    y_train, x_train, x_test, y_test = create_dataset()

    model = multilayer_perceptron()

    print(model.layers)

    # Optimizer for the centroid variables (cluster_step).
    optimizer_c = tf.keras.optimizers.Adam(FLAGS.learning_rate)
    # Optimizer for all remaining weights (train_step). Note: its learning
    # rate is fixed at 0.001 and does not follow --learning_rate.
    optimizer = tf.keras.optimizers.Adam(0.001)

    uaib_layer = model.get_layer('dense_uaib')

    @tf.function
    def cluster_step(inputs):
        """Update the bottleneck centroids and centroid probabilities."""

        def centroid_step_fn(inputs):
            """One optimizer_c step on 'centroid' variables, minimising mean uncertainty."""
            with tf.GradientTape() as tape:
                outputs = model(inputs, training=True)
                _, uncertainty = tf.split(outputs, 2, axis=-1)
                loss = tf.reduce_mean(uncertainty)

            grads = tape.gradient(loss, model.trainable_variables)
            # Train the centroids only; everything else belongs to train_step.
            grads_and_vars = [
                (grad, var)
                for grad, var in zip(grads, model.trainable_variables)
                if 'centroid' in var.name
            ]
            optimizer_c.apply_gradients(grads_and_vars)

        def centroid_probs_step_fn(inputs):
            """Forward pass in training mode.

            NOTE(review): the output is discarded. Presumably DenseUAIB
            accumulates centroid-probability statistics during this call,
            between reset_centroid_probs() and set_centroid_probs(). Confirm in
            layers.DenseUAIB.
            """
            model(inputs, training=True)

        # Set the layer's ``initialized`` flag to True.
        initialized_centroids_op = uaib_layer.initialized.assign(
            tf.constant(True, dtype=tf.bool))
        uaib_layer.add_update(initialized_centroids_op)

        # M-step: train centroids.
        uaib_layer.reset_centroids()
        centroid_step_fn(inputs)
        uaib_layer.set_centroids()

        # M-step: compute prior centroid probabilities.
        uaib_layer.reset_centroid_probs()
        centroid_probs_step_fn(inputs)
        uaib_layer.set_centroid_probs()

    @tf.function
    def train_step(inputs, labels):
        """One optimizer step on all non-centroid variables."""
        with tf.GradientTape() as tape:
            outputs = model(inputs, training=True)
            predictions, uncertainty = tf.split(outputs, 2, axis=-1)

            # Unit-variance Gaussian likelihood centred on the prediction.
            distribution = ed.Normal(loc=predictions, scale=1.0)
            # Per-example NLL (not reduced). ``loss`` is therefore a
            # per-example tensor, and tape.gradient differentiates its sum.
            negative_log_likelihood = -distribution.distribution.log_prob(labels)

            # Layer-registered losses, excluding the cluster-distance terms.
            l2_loss = sum(
                l for l in model.losses if "cluster_distances" not in l.name)

            if FLAGS.verbose:
                tf.print(negative_log_likelihood, summarize=-1)

            loss = (negative_log_likelihood + l2_loss
                    + FLAGS.beta * tf.reduce_mean(uncertainty))

        grads = tape.gradient(loss, model.trainable_variables)

        # Train encoder/decoder: every variable except the centroids.
        grads_and_vars = [
            (grad, var)
            for grad, var in zip(grads, model.trainable_variables)
            if 'centroid' not in var.name
        ]
        optimizer.apply_gradients(grads_and_vars)

    for epoch in range(FLAGS.epochs):
        if FLAGS.verbose:
            print('epoch')
            print(epoch)

        train_step(x_train, y_train)
        cluster_step(x_train)

    # model.predict returns a NumPy array; column 0 is the prediction and
    # column 1 the uncertainty.
    output = model.predict(x_test)
    y_pred, uncertainty = np.split(output, 2, axis=-1)

    if FLAGS.verbose:
        print(uncertainty)

    _plot_predictions(x_train, y_train, x_test, y_test, y_pred, uncertainty,
                      show_legend=(FLAGS.example == 1))


if __name__ == '__main__':
    app.run(main)
