import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Activation, Input, Dense, GlobalAveragePooling2D, BatchNormalization, Flatten


### Design Interval SoftMax Activation Function ###
@keras.saving.register_keras_serializable(package="my_package", name="IntSoftMax")
def IntSoftMax(inputs):
    """Interval SoftMax: map interval logits to lower/upper class probabilities.

    The last axis of ``inputs`` holds ``2 * Nc`` values. The first ``Nc`` are
    interval centers ``c``; the last ``Nc`` are unconstrained radii, made
    non-negative as ``r = softplus(radius)``. For each class ``i``::

        lo_i = exp(c_i - r_i) / (sum_{j != i} exp(c_j) + exp(c_i - r_i))
        hi_i = exp(c_i + r_i) / (sum_{j != i} exp(c_j) + exp(c_i + r_i))

    Args:
        inputs: Tensor whose last dimension is even (``2 * Nc``). Any leading
            dimensions (e.g. the batch axis) are carried through unchanged.

    Returns:
        Tensor of the same shape as ``inputs``: the lower probabilities
        ``lo`` followed by the upper probabilities ``hi`` along the last axis.

    Raises:
        ValueError: If the last dimension is unknown or odd, since it cannot
            then be split into equal center/radius halves.
    """
    width = inputs.shape[-1]
    if width is None or width % 2 != 0:
        raise ValueError(
            "IntSoftMax expects an even, statically known last dimension "
            f"(2 * num_classes), got {width}"
        )
    nc = width // 2

    # Split the last axis into interval centers and (non-negative) radii.
    center = inputs[..., :nc]
    radius = tf.math.softplus(inputs[..., nc:])

    # Every ratio below is invariant to subtracting the same constant from all
    # centers (and hence from every exponent), so shift by the per-sample max
    # center. Without this, exp(center) overflows to inf for float32 logits
    # above ~88 and the ratios become inf/inf = NaN.
    shift = tf.stop_gradient(tf.reduce_max(center, axis=-1, keepdims=True))
    center = center - shift

    exp_center = tf.math.exp(center)
    # sum_{j != i} exp(c_j): computed once and shared by both bounds.
    others = tf.reduce_sum(exp_center, axis=-1, keepdims=True) - exp_center

    exp_lo = tf.math.exp(center - radius)
    exp_hi = tf.math.exp(center + radius)
    lo = exp_lo / (others + exp_lo)
    hi = exp_hi / (others + exp_hi)

    # Generate output: [lower probabilities, upper probabilities].
    return tf.concat([lo, hi], axis=-1)


def crenet_res50(input_shape, num_classes):
    """Build a CreNet classifier on a randomly initialised ResNet50 backbone.

    Architecture: ResNet50 (no top, ``weights=None``) -> global average
    pooling -> ``Dense(2 * num_classes)`` -> batch normalisation ->
    ``IntSoftMax``.

    Args:
        input_shape: Shape of one input sample excluding the batch axis, e.g.
            ``(32, 32, 3)``. Keras' ResNet50 requires height and width of at
            least 32 when ``include_top=False``.
        num_classes: Number of target classes.

    Returns:
        An uncompiled ``keras.Model`` named ``'CreNet_RES50'``. Each sample's
        output has ``2 * num_classes`` values: per-class lower probabilities
        followed by per-class upper probabilities.
    """
    inputs = Input(input_shape)
    # Build the backbone for the caller's input shape so its layers match
    # `inputs`. `classes` is omitted because Keras ignores it when
    # include_top=False.
    backbone = tf.keras.applications.resnet50.ResNet50(
        include_top=False, weights=None, input_shape=input_shape
    )
    x = backbone(inputs)
    x = GlobalAveragePooling2D()(x)
    # One center and one radius logit per class, consumed by IntSoftMax.
    x = Dense(units=2 * num_classes, activation=None)(x)
    x = BatchNormalization()(x)
    outputs = Activation(IntSoftMax)(x)

    model = keras.Model(inputs, outputs, name='CreNet_RES50')

    return model
    