from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
import tensorflow as tf


class CTloss(Layer):
    """Identity layer that registers an auxiliary "CT" loss on the model.

    The layer is called on a pair of feature tensors
    ``[f_inbn, f_innobn]`` and returns ``f_inbn`` unchanged. As a side
    effect it adds ``beta * (mean_l + var_l + mse_l)`` to the layer losses,
    where:

    * ``mean_l`` is the absolute difference between the global means of the
      two L2-normalised (along the last axis) feature tensors,
    * ``var_l`` is the absolute difference between their global variances,
    * ``mse_l`` is the mean squared error between the raw (un-normalised)
      feature tensors.

    Args:
        beta: Scalar weight applied to the summed loss terms.
        *args: Accepted for backward compatibility and ignored.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, beta, *args, **kwargs):
        # Initialise the base layer first so that Keras' own bookkeeping
        # exists before any custom attribute is assigned.
        super(CTloss, self).__init__(**kwargs)
        self.is_placeholder = True
        self.beta = beta

    def call(self, inputs, **kwargs):
        """Add the CT loss for ``inputs`` and pass the first tensor through.

        Args:
            inputs: A pair ``(f_inbn, f_innobn)`` of feature tensors.
            **kwargs: Unused.

        Returns:
            ``f_inbn``, unchanged.
        """
        f_inbn, f_innobn = inputs

        # The mean/variance statistics are compared on L2-normalised features.
        f_inbn_n = tf.nn.l2_normalize(f_inbn, axis=-1)
        f_innobn_n = tf.nn.l2_normalize(f_innobn, axis=-1)
        mean_l = K.abs(K.mean(f_inbn_n) - K.mean(f_innobn_n))
        var_l = K.abs(K.var(f_inbn_n) - K.var(f_innobn_n))

        # The MSE term is computed on the raw features, not the normalised ones.
        mse_l = K.mean(tf.keras.losses.mean_squared_error(f_inbn, f_innobn))

        # The former ``inputs=`` keyword is only kept by TF2 for backward
        # compatibility and is rejected by newer Keras releases, so the loss
        # is registered without it.
        self.add_loss(self.beta * (mean_l + var_l + mse_l))

        return f_inbn

    def get_config(self):
        """Return the layer configuration, including ``beta``.

        Needed so that models containing this layer can be serialised and
        rebuilt, since ``beta`` is a required constructor argument.
        """
        config = super(CTloss, self).get_config()
        config['beta'] = self.beta
        return config

def base_model_mnist(input_shape, num_cls, use_bn=True, ct=False):
    """Build the small convolutional network named ``'MNIST_Model'``.

    Architecture: two 16-filter 3x3 convolutions, a 2x2 max-pool, two
    32-filter 3x3 convolutions, a 2x2 max-pool, a flatten and a 256-unit
    dense layer. Every convolution/dense layer is followed by an optional
    batch normalisation and a ReLU; the last ReLU is named ``'fvec'``.

    Args:
        input_shape: Shape passed to ``Input``.
        num_cls: Number of units of the softmax ``'predictions'`` layer
            (only used when ``ct`` is falsy).
        use_bn: When truthy, insert ``BatchNormalization`` before each ReLU.
        ct: When truthy, return the feature extractor ending at ``'fvec'``
            instead of the classifier.

    Returns:
        A Keras ``Model`` mapping the input to either the ``'fvec'``
        features or the softmax predictions.
    """

    def _bn_relu(tensor, relu_name=None):
        # Optional batch-norm followed by a (possibly named) ReLU.
        if use_bn:
            tensor = BatchNormalization()(tensor)
        return Activation('relu', name=relu_name)(tensor)

    image = Input(input_shape)

    # One row per convolution: (filters, conv name, relu name, pool name).
    # ``None`` means "auto-named ReLU" / "no pooling after this stage".
    conv_plan = (
        (16, 'conv0', None, None),
        (16, 'conv1', None, 'mp1'),
        (32, 'conv2', None, None),
        (32, 'conv3', 'ac3', 'mp2'),
    )

    features = image
    for filters, conv_name, relu_name, pool_name in conv_plan:
        features = Conv2D(filters, (3, 3), strides=(1, 1), name=conv_name)(features)
        features = _bn_relu(features, relu_name)
        if pool_name is not None:
            features = MaxPooling2D((2, 2), strides=(2, 2), name=pool_name)(features)

    features = Flatten(name='flat')(features)
    features = Dense(256)(features)
    features = _bn_relu(features, 'fvec')

    outputs = features
    if not ct:
        # Classifier variant: append the softmax head on top of 'fvec'.
        outputs = Dense(num_cls, activation='softmax', name='predictions')(features)

    return Model(inputs=image, outputs=outputs, name='MNIST_Model')


def ct_model_mnist(input_shape, num_cls,
                   teacher_weights='./weights_ct/mnist_teacher.h5',
                   beta=500, head_units=2):
    """Build the student model trained with the CT loss against a frozen teacher.

    A teacher classifier without batch normalisation is built, its weights
    are loaded from ``teacher_weights`` and all of its layers are frozen.
    A student with batch normalisation is built as a feature extractor.
    Both are applied to the same input; their ``'fvec'`` features are fed
    to ``CTloss`` (student first, so the student features pass through),
    and a softmax head is placed on top of the student features.

    Args:
        input_shape: Shape passed to ``Input`` and to both sub-models.
        num_cls: Number of classes of the teacher classifier; it has to
            match the architecture stored in ``teacher_weights``.
        teacher_weights: Path of the teacher weight file to load.
        beta: Weight of the CT loss term.
        head_units: Number of units of the final softmax layer.
            NOTE(review): the default of 2 reproduces the previously
            hard-coded value and is independent of ``num_cls`` - confirm
            whether the head was meant to have ``num_cls`` units.

    Returns:
        A Keras ``Model`` mapping the input to the student's softmax output.
    """
    in_ = Input(input_shape)

    teacher_model = base_model_mnist(input_shape, num_cls, use_bn=False, ct=False)
    teacher_model.load_weights(teacher_weights)

    # Freeze all the layers of the teacher model so only the student
    # (and the new head) contribute trainable weights.
    for layer in teacher_model.layers:
        layer.trainable = False

    # Teacher truncated at its 'fvec' activation, i.e. used as a feature extractor.
    model_base = Model(teacher_model.input, teacher_model.get_layer('fvec').output)
    f_teacher = model_base(in_)

    student = base_model_mnist(input_shape, num_cls, use_bn=True, ct=True)
    f_student = student(in_)

    # CTloss returns its first input unchanged and registers the loss.
    f_student = CTloss(beta)([f_student, f_teacher])
    pred = Dense(head_units, activation="softmax")(f_student)

    return Model(in_, pred)


