import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Activation, Input, Dense, GlobalAveragePooling2D, BatchNormalization, Flatten

### Design Interval SoftMax Activation Function ###
@keras.saving.register_keras_serializable(package="my_package", name="IntSoftMax")
def IntSoftMax(inputs):
  """Interval softmax: turn (center, radius) logits into lower/upper probabilities.

  The last axis of ``inputs`` must have even size ``2 * Nc``. Its first ``Nc``
  entries are interval centers ``c``; its last ``Nc`` entries are raw radii,
  made nonnegative with softplus (``r``). For each class ``i``:

      lo_i = exp(c_i - r_i) / (sum_{j != i} exp(c_j) + exp(c_i - r_i))
      hi_i = exp(c_i + r_i) / (sum_{j != i} exp(c_j) + exp(c_i + r_i))

  i.e. the softmax probability of class ``i`` when only its own logit is moved
  to the lower / upper end of its interval while the others stay at their
  centers.

  Args:
    inputs: tensor whose last axis has size ``2 * Nc``; leading axes (e.g. the
      batch axis) are preserved.

  Returns:
    Tensor with the same shape as ``inputs``: ``concat([lo, hi], axis=-1)``.

  Raises:
    ValueError: if the last dimension is not statically known or is odd.
  """
  n_units = inputs.shape[-1]
  if n_units is None or n_units % 2 != 0:
    raise ValueError(
        f"IntSoftMax expects an even, statically known last dimension; got {n_units}.")
  Nc = n_units // 2

  # Split into interval centers and (nonnegative) radii.
  center = inputs[..., :Nc]
  radius_nonneg = tf.math.softplus(inputs[..., Nc:])

  # Shift every exponent by the per-row maximum center. The shift cancels
  # between numerator and denominator, but it keeps exp() from overflowing
  # for large centers (e.g. float32 exp overflows above ~88). stop_gradient
  # is safe because the result does not depend on the shift value.
  shift = tf.stop_gradient(tf.reduce_max(center, axis=-1, keepdims=True))
  shifted = center - shift

  exp_center = tf.exp(shifted)
  exp_lo = tf.exp(shifted - radius_nonneg)
  # NOTE: this can still overflow if a single radius is extremely large (same as before).
  exp_hi = tf.exp(shifted + radius_nonneg)

  # Mass of all other classes at their centers: sum_{j != i} exp(c_j).
  others = tf.reduce_sum(exp_center, axis=-1, keepdims=True) - exp_center

  # Compute lower and upper probabilities.
  lo = exp_lo / (others + exp_lo)
  hi = exp_hi / (others + exp_hi)

  # Generate output: [lower | upper].
  return tf.concat([lo, hi], axis=-1)



class CreNetModel(keras.Model):
    """Interval-probability classifier built on a pretrained backbone.

    The wrapped functional network (``self.model``) ends in
    ``Dense(2 * classes)`` -> ``BatchNormalization`` -> ``IntSoftMax``, so each
    prediction is ``[lo | hi]``. The first ``classes`` columns are lower
    probabilities and the last ``classes`` columns are upper probabilities.

    Training minimizes ``loss_up + loss_lo_mod``:

    * ``loss_up`` is the mean cross-entropy of the upper probabilities over
      the whole batch.
    * ``loss_lo_mod`` is the mean cross-entropy of the lower probabilities
      over only the samples whose lower-probability loss is at least the
      k-th largest, with ``k = floor(delta * batch_size)``.

    Args:
        backbone: backbone identifier; only ``'RESNET50'`` is supported.
        classes: number of classes.
        input_shape: shape of one input sample (without the batch axis).
        delta: fraction of each batch used for the lower-probability loss.
        weights: forwarded as ``weights=`` to the backbone constructor.
        task: ``'multi'`` (categorical cross-entropy / accuracy) or
            ``'binary'`` (binary cross-entropy / accuracy).

    Raises:
        ValueError: if ``task`` or ``backbone`` is not supported.
    """

    def __init__(self,
                 backbone,
                 classes,
                 input_shape,
                 delta,
                 weights,
                 task='multi',
                 *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Fail fast on an unsupported task, before the backbone is built.
        # Previously this only printed a message and left the accuracy
        # trackers undefined, which broke `metrics` later.
        if task not in ('multi', 'binary'):
            raise ValueError(
                f"Invalid task {task!r}; expected 'multi' or 'binary'.")

        self.backbone = backbone
        self.classes = classes
        # Stored under different names, presumably to avoid clashing with
        # keras.Model's own `input_shape` / `weights` attributes.
        self.input_size = input_shape
        self.delta = delta
        self.pre_weights = weights

        # Specify type of task.
        self.task = task

        # Build the wrapped functional network.
        self.model = self.build_crenet(self.creator)

        # Performance monitoring.
        self.total_loss_tracker = keras.metrics.Mean(name="Loss-T")
        self.upper_loss_tracker = keras.metrics.Mean(name="Loss-U")
        self.lower_loss_tracker = keras.metrics.Mean(name="Loss-L")
        self.val_upper_loss_tracker = keras.metrics.Mean(name="ValLoss-U")
        self.val_lower_loss_tracker = keras.metrics.Mean(name="ValLoss-L")

        acc_metric = (keras.metrics.CategoricalAccuracy if task == 'multi'
                      else keras.metrics.BinaryAccuracy)
        self.upper_acc_tracker = acc_metric(name='Acc-U')
        self.lower_acc_tracker = acc_metric(name='Acc-L')
        self.val_upper_acc_tracker = acc_metric(name='ValAcc-U')
        self.val_lower_acc_tracker = acc_metric(name='ValAcc-L')

    def creator(self, inputs):
        """Build the functional model from the symbolic ``inputs``.

        Pipeline: x7 UpSampling2D -> backbone (no top) -> global average
        pooling -> Dense(1024, relu) -> Dense(512, relu) -> Dense(2 * classes)
        -> BatchNormalization -> IntSoftMax.

        The backbone is built for 224x224x3 inputs and ``inputs`` is upsampled
        x7, so samples are expected to be 32x32x3.

        Raises:
            ValueError: if ``self.backbone`` is not ``'RESNET50'``.
        """
        if self.backbone != 'RESNET50':
            # Previously only printed, then crashed on the unbound `base`.
            raise ValueError(
                f"Invalid backbone {self.backbone!r}; expected 'RESNET50'.")

        base = tf.keras.applications.resnet50.ResNet50(
            include_top=False, weights=self.pre_weights,
            input_shape=(224, 224, 3), classes=self.classes)

        x = tf.keras.layers.UpSampling2D(size=(7, 7))(inputs)
        x = base(x)
        x = GlobalAveragePooling2D()(x)
        # No-op after global pooling; kept to preserve the original layer structure.
        x = Flatten()(x)
        x = Dense(units=1024, activation='relu')(x)
        x = Dense(units=512, activation='relu')(x)
        x = Dense(units=2*self.classes, activation=None)(x)
        x = BatchNormalization()(x)
        outputs = Activation(IntSoftMax)(x)

        return keras.Model(inputs, outputs, name='CreNet_'+self.backbone)

    def build_crenet(self, creator):
        """Create an ``Input`` of shape ``self.input_size`` and pass it to ``creator``."""
        inputs = Input(self.input_size)
        return creator(inputs)

    def call(self, inputs, training=None):
        """Forward pass through the wrapped network; returns ``[lo | hi]``."""
        return self.model(inputs, training=training)

    def _split_preds(self, preds):
        """Split network output into ``(lower, upper)`` probability halves.

        The network ends in ``Dense(2 * classes)`` + ``IntSoftMax``, which
        concatenates ``[lo, hi]`` on the last axis, so the split point is
        ``self.classes``.
        """
        return preds[:, :self.classes], preds[:, self.classes:]

    def _per_sample_loss(self, labels, probs):
        """Return the unreduced cross-entropy for ``self.task`` (one value per sample)."""
        if self.task == 'multi':
            loss_fn = tf.keras.losses.CategoricalCrossentropy(
                reduction=tf.keras.losses.Reduction.NONE)
        else:
            loss_fn = tf.keras.losses.BinaryCrossentropy(
                reduction=tf.keras.losses.Reduction.NONE)
        return loss_fn(labels, probs)

    def _hardest_mask(self, per_sample_loss, batch_num=None):
        """Boolean mask of samples whose loss is >= the k-th largest loss.

        ``k = floor(delta * n)``, where ``n`` is ``batch_num`` if given and
        otherwise the actual batch size. ``k`` is clamped to
        ``[1, batch size]``: an empty selection would make the mean NaN, and
        the last batch may be smaller than ``batch_num``. Ties at the
        threshold are all kept, so more than ``k`` samples may be selected.
        """
        batch_size = tf.shape(per_sample_loss)[0]
        n = batch_size if batch_num is None else batch_num
        # float64 mirrors the previous NumPy computation of floor(delta * n).
        k = tf.cast(
            tf.floor(tf.cast(self.delta, tf.float64) * tf.cast(n, tf.float64)),
            tf.int32)
        k = tf.clip_by_value(k, 1, batch_size)

        loss_sorted = tf.sort(per_sample_loss, direction='DESCENDING')
        bound_value = tf.gather(loss_sorted, k - 1)
        return tf.greater_equal(per_sample_loss, bound_value)

    def train_step(self, data, train_batch_num=None):
        """Run one optimization step on ``data = (inputs, labels)``.

        Args:
            data: ``(inputs, labels)`` batch, as passed by ``fit``.
            train_batch_num: nominal batch size used to choose how many
                samples enter the lower-probability loss. Defaults to the
                actual batch size. ``fit`` calls ``train_step(data)`` with a
                single argument, so this must be optional.

        Returns:
            Dict with the running training losses and accuracies.
        """
        inputs, labels = data

        with tf.GradientTape() as tape:
            preds = self.model(inputs, training=True)
            preds_lo, preds_up = self._split_preds(preds)

            # Lower probabilities: average only over the highest-loss samples.
            loss_lo = self._per_sample_loss(labels, preds_lo)
            chosen = self._hardest_mask(loss_lo, train_batch_num)
            loss_lo_mod = tf.reduce_mean(tf.boolean_mask(loss_lo, chosen))

            # Upper probabilities: average over the whole batch.
            loss_up = tf.reduce_mean(self._per_sample_loss(labels, preds_up))

            loss_total = loss_lo_mod + loss_up

        grads = tape.gradient(loss_total, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

        self.total_loss_tracker.update_state(loss_total)
        self.upper_loss_tracker.update_state(loss_up)
        self.lower_loss_tracker.update_state(loss_lo_mod)
        self.upper_acc_tracker.update_state(labels, preds_up)
        self.lower_acc_tracker.update_state(labels, preds_lo)

        return {"Loss-T": self.total_loss_tracker.result(), "Loss-U": self.upper_loss_tracker.result(),
                "Loss-L": self.lower_loss_tracker.result(), 'Acc-U': self.upper_acc_tracker.result(),
                'Acc-L': self.lower_acc_tracker.result()}

    def test_step(self, data):
        """Evaluate one batch ``(inputs, labels)``; track per-sample losses and accuracies.

        Unlike training, the lower-probability loss here is averaged over all
        samples (no highest-loss filtering).
        """
        inputs, labels = data
        preds = self.model(inputs, training=False)
        preds_lo, preds_up = self._split_preds(preds)

        loss_lo = self._per_sample_loss(labels, preds_lo)
        loss_up = self._per_sample_loss(labels, preds_up)

        self.val_upper_loss_tracker.update_state(loss_up)
        self.val_lower_loss_tracker.update_state(loss_lo)

        # Update validation accuracy.
        self.val_upper_acc_tracker.update_state(labels, preds_up)
        self.val_lower_acc_tracker.update_state(labels, preds_lo)

        return {'ValAcc-U': self.val_upper_acc_tracker.result(),
                'ValAcc-L': self.val_lower_acc_tracker.result(),
                "ValLoss-U": self.val_upper_loss_tracker.result(),
                "ValLoss-L": self.val_lower_loss_tracker.result()
                }

    def save(self, file_name, *args, **kwargs):
        """Save the wrapped functional network (not this wrapper) to ``file_name``."""
        return self.model.save(file_name, *args, **kwargs)

    def predict(self, inputs, *args, **kwargs):
        """Predict ``[lo | hi]`` probabilities with the wrapped network."""
        return self.model.predict(inputs, *args, **kwargs)

    @property
    def metrics(self):
        """All trackers, listed so Keras resets them between epochs / evaluations."""
        return [self.total_loss_tracker, self.upper_loss_tracker, self.lower_loss_tracker,
                self.upper_acc_tracker, self.lower_acc_tracker,
                self.val_upper_acc_tracker, self.val_lower_acc_tracker,
                self.val_upper_loss_tracker, self.val_lower_loss_tracker,
                ]

    def get_config(self):
        """Return the constructor arguments (keys match ``__init__`` parameter names)."""
        return {
                "backbone": self.backbone,
                "delta": self.delta,
                "input_shape": self.input_size,
                "classes": self.classes,
                "weights": self.pre_weights,
                "task": self.task,
                }