import tensorflow as tf
from tensorflow import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers import Conv2D, Conv3D, Input, AveragePooling2D, \
    multiply, Dense, Dropout, Flatten, AveragePooling3D, LSTM, Concatenate, TimeDistributed, Lambda, Attention, BatchNormalization, GRU, Bidirectional
from tensorflow.python.keras.models import Model
from tensorflow.keras import regularizers
from tensorflow.python.ops.gen_batch_ops import Batch
from tensorflow.keras.losses import mean_squared_error

#%%

def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient, reduced over the last axis.

    Source: https://gist.github.com/wassname/7793e2058c5c9dacb5212c0ac0b18a8a
    Dice = (2*|X & Y|)/ (|X|+ |Y|)
         =  2*sum(|A*B|)/(sum(A^2)+sum(B^2))
    ref: https://arxiv.org/pdf/1606.04797v1.pdf

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor, broadcast-compatible with ``y_true``.
        smooth: Added to numerator and denominator to avoid 0/0.

    Returns:
        Tensor of Dice scores with the last axis reduced away.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    numerator = 2. * overlap + smooth
    denominator = K.sum(K.square(y_true), -1) + K.sum(K.square(y_pred), -1) + smooth
    return numerator / denominator

def dice_coef_loss(y_true, y_pred):
    """Return ``1 - dice_coef(y_true, y_pred)`` so that lower is better."""
    score = dice_coef(y_true, y_pred)
    return 1 - score

def diff(y):
    """Return the first-order forward difference of ``y`` along axis 1.

    ``out[:, i] = y[:, i + 1] - y[:, i]``; axis 1 shrinks by one element.
    Works for any array/tensor type that supports NumPy-style slicing.
    """
    later, earlier = y[:, 1:], y[:, :-1]
    return later - earlier

def first_and_second_derivative_loss(y_true, y_pred, loss_weight=0.5):
    """Weighted sum of a waveform MSE and a mean-centred first-difference MSE.

    Combines ``first_derivative_loss`` (plain per-sample MSE on the signal)
    with ``second_derivative_loss`` (MSE between the mean-removed first
    differences produced by ``diff``). Despite the names, only a first-order
    difference is ever taken.

    Args:
        y_true: Target signals. Differences are taken along axis 1 and
            reductions along the last axis, so presumably ``(batch, time)``
            is expected (the two axes coincide only for rank-2 input).
        y_pred: Predicted signals, same shape as ``y_true``.
        loss_weight: Weight of the waveform MSE; the difference MSE gets
            ``1 - loss_weight``. Defaults to 0.5, the previously hard-coded
            value, so callers using ``fn(y_true, y_pred)`` (e.g. Keras) are
            unaffected.

    Returns:
        Per-sample loss tensor with the last axis reduced away.
    """
    mse = first_derivative_loss(y_true, y_pred)
    mse_diff = second_derivative_loss(y_true, y_pred)
    return loss_weight * mse + (1 - loss_weight) * mse_diff

def first_derivative_loss(y_true, y_pred):
    """Per-sample mean squared error between the raw signals.

    Despite the name, no derivative is taken: this is the plain MSE of
    ``y_true`` vs ``y_pred`` averaged over the last axis.
    """
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

def second_derivative_loss(y_true, y_pred):
    """MSE between the mean-removed first differences of two signals.

    Each signal is differenced along axis 1 with ``diff`` (a first-order
    difference, not a second derivative), centred by subtracting its mean
    over the last axis, and the squared residual is averaged over the last
    axis.

    Returns:
        Per-sample loss tensor with the last axis reduced away.
    """
    def _centered_diff(signal):
        # First difference with its mean over the last axis removed.
        delta = diff(signal)
        return delta - tf.reduce_mean(delta, axis=-1, keepdims=True)

    residual = _centered_diff(y_true) - _centered_diff(y_pred)
    return tf.reduce_mean(tf.square(residual), axis=-1)

def second_derivative_peak_loss(y_true, y_pred):
    """Waveform MSE plus an MSE restricted to masked (peak) positions.

    Channel 0 of ``y_true``/``y_pred`` holds the signal; channel 1 of
    ``y_true`` is used as a multiplicative mask (presumably 1 at systolic/
    diastolic peak positions and 0 elsewhere — confirm against the data
    pipeline). The same mask from ``y_true`` is applied to both signals.
    No derivative is computed despite the name.

    Returns:
        ``mse(signal) + mse(masked signal)`` per sample.
    """
    signal_true = y_true[:, :, 0]
    peak_mask = y_true[:, :, 1]
    signal_pred = y_pred[:, :, 0]

    waveform_mse = mean_squared_error(signal_true, signal_pred)
    peak_mse = mean_squared_error(signal_true * peak_mask, signal_pred * peak_mask)
    return waveform_mse + peak_mse

class Attention_mask(tf.keras.layers.Layer):
    """Rescale a soft attention map by its sum over axes 1 and 2.

    The output is ``x / sum_{axis1,axis2}(x) * dim1 * dim2 * 0.5``, so over
    axes 1 and 2 it sums to half the number of positions there. Dimensions
    1 and 2 must be statically known (read via ``K.int_shape``).
    """

    def call(self, x):
        spatial_total = K.sum(K.sum(x, axis=1, keepdims=True), axis=2, keepdims=True)
        static_shape = K.int_shape(x)
        height, width = static_shape[1], static_shape[2]
        return x / spatial_total * height * width * 0.5

    def get_config(self):
        # No extra constructor arguments; the base config is sufficient.
        return super(Attention_mask, self).get_config()

# %%

def DeepPhy(nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5,
            pool_size=(2, 2), nb_dense=128):
    """Build a two-stream 2D convolutional attention network with one output.

    A motion stream on ``diff_input`` is gated twice by attention masks that
    an appearance stream computes from ``rawf_input``. The gated motion
    features are flattened and regressed to a single value.

    Args:
        nb_filters1: Filters in the first conv pair of each stream.
        nb_filters2: Filters in the second conv pair of each stream.
        input_shape: Shape (without batch) of both inputs.
        kernel_size: Conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after the hidden dense layer.
        pool_size: Average-pooling window.
        nb_dense: Units in the hidden dense layer.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and a single
        one-unit ``Dense`` output.
    """
    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Motion stream, stage 1
    d1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(diff_input)
    d2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(d1)

    # Appearance stream, stage 1
    r1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(r1)

    # Attention mask 1: 1x1 sigmoid conv on appearance features, normalised by
    # Attention_mask, then multiplied into the motion features.
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    d3 = AveragePooling2D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)

    r3 = AveragePooling2D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)

    # Motion stream, stage 2
    d5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(d4)
    d6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(d5)

    # Appearance stream, stage 2
    r5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(r5)

    # Attention mask 2
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])

    d7 = AveragePooling2D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)

    # Regression head
    d9 = Flatten()(d8)
    d10 = Dense(nb_dense, activation='tanh')(d9)
    d11 = Dropout(dropout_rate2)(d10)
    out = Dense(1)(d11)
    model = Model(inputs=[diff_input, rawf_input], outputs=out)
    return model


#%% 2DCNN-MT
def DeepPhys_2DCNN_MT(nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5,
            pool_size=(2, 2), nb_dense=128, use_dataloader=False):
    """Build a two-stream 2D attention CNN with separate pulse and resp heads.

    The trunk matches ``DeepPhy``: a motion stream on ``diff_input`` is gated
    by attention masks computed from an appearance stream on ``rawf_input``.
    The flattened gated motion features feed two independent dense heads.

    Args:
        nb_filters1: Filters in the first conv pair of each stream.
        nb_filters2: Filters in the second conv pair of each stream.
        input_shape: Shape (without batch) of both inputs.
        kernel_size: Conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after each head's hidden dense layer.
        pool_size: Average-pooling window.
        nb_dense: Units in each head's hidden dense layer.
        use_dataloader: Accepted but not used by this function.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and outputs
        ``[pulse, resp]``, each a one-unit ``Dense`` layer.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Motion stream, stage 1
    d1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(diff_input)
    d2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(d1)

    # Appearance stream, stage 1
    r1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(r1)

    # Attention mask 1 from appearance, applied to motion features
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    d3 = AveragePooling2D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)

    r3 = AveragePooling2D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)

    # Motion stream, stage 2
    d5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(d4)
    d6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(d5)

    # Appearance stream, stage 2
    r5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(r5)

    # Attention mask 2
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])

    d7 = AveragePooling2D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)

    # Shared flattened features feed both task heads
    d9 = Flatten()(d8)
    # Pulse head
    d10_y = Dense(nb_dense, activation='tanh')(d9)
    d11_y = Dropout(dropout_rate2)(d10_y)
    out_y = Dense(1, name='pulse')(d11_y)

    # Respiration head
    d10_r = Dense(nb_dense, activation='tanh')(d9)
    d11_r = Dropout(dropout_rate2)(d10_r)
    out_r = Dense(1, name='resp')(d11_r)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_r])
    return model


#%%

class Attention_Split(tf.keras.layers.Layer):
    """Return the mean of each spatial quadrant of the batch-averaged input.

    The input is averaged over the batch axis and squeezed (every size-1
    axis is removed), so the result must be 2-D for the row/column splits
    below. Both remaining dimensions must be even, as ``tf.split`` requires
    equal halves. A commented example elsewhere in this module feeds a
    ``[20, 36, 36, 1]`` tensor.

    Returns:
        Tuple of four scalars: (top-left, top-right, bottom-left,
        bottom-right) quadrant means.
    """

    def call(self, x):
        mean_map = tf.squeeze(K.mean(x, axis=0))
        top_half, bottom_half = tf.split(mean_map, 2, axis=0)
        # Row-major quadrant order: top-left, top-right, bottom-left, bottom-right.
        quadrants = [quad
                     for half in (top_half, bottom_half)
                     for quad in tf.split(half, 2, axis=1)]
        return tuple(tf.reduce_mean(quad) for quad in quadrants)

# at_split = Attention_Split()
# test_tensor = tf.random.normal([20, 36, 36, 1])
# a, b, c, d = at_split(test_tensor)
# print('a', a)
# print('b', b)
# print('c', c)
# print('d', d)

#%%

class TSM(tf.keras.layers.Layer):
    """Temporal shift of channel groups across consecutive frames.

    The leading axis of ``x`` is regrouped into clips of ``n_frame`` frames.
    The channels are then split into three groups:

    * group 1 (``c // fold_div`` channels): frame t receives frame t+1's
      values, and the last frame is zero-filled;
    * group 2 (``c // fold_div`` channels): frame t receives frame t-1's
      values, and the first frame is zero-filled;
    * group 3 (the remaining channels): left unchanged.
    """

    def call(self, x, n_frame, fold_div=3):
        # h, w, c must be static; the leading axis must hold whole clips
        # of n_frame frames for the reshape below to succeed.
        _, height, width, channels = x.shape
        clips = K.reshape(x, (-1, n_frame, height, width, channels))
        fold = channels // fold_div
        remainder = channels - (fold_div - 1) * fold
        to_left, to_right, untouched = tf.split(clips, [fold, fold, remainder], axis=-1)

        # Shift left: drop the first frame, append one all-zero frame.
        left_shifted = tf.concat([to_left[:, 1:], tf.zeros_like(to_left[:, :1])], axis=1)

        # Shift right: prepend one all-zero frame, drop the last frame.
        right_shifted = tf.concat([tf.zeros_like(to_right[:, :1]), to_right[:, :-1]], axis=1)

        shifted = tf.concat([left_shifted, right_shifted, untouched], axis=-1)
        return K.reshape(shifted, (-1, height, width, channels))

    def get_config(self):
        # No constructor arguments beyond the base layer's.
        return super(TSM, self).get_config()


def TSM_Cov2D(x, n_frame, nb_filters=128, kernel_size=(3, 3), activation='tanh', padding='same'):
    """Apply a ``TSM`` temporal shift to ``x``, then a ``Conv2D``.

    Args:
        x: Frame-stacked feature tensor; see ``TSM`` for the layout.
        n_frame: Frames per clip, forwarded to ``TSM``.
        nb_filters, kernel_size, activation, padding: Forwarded to ``Conv2D``.

    Returns:
        The convolved tensor.
    """
    shift = TSM()
    conv = Conv2D(nb_filters, kernel_size, padding=padding, activation=activation)
    return conv(shift(x, n_frame))

#%% CONV2D + TSM

def TS_CAN(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5,
            pool_size=(2, 2), nb_dense=128):
    """Build a deeper TS-CAN: temporal-shift motion stream + 2D appearance stream.

    In each stage, the motion stream uses four ``TSM_Cov2D`` layers and the
    appearance stream uses four plain ``Conv2D`` layers. An attention mask
    from the appearance stream gates the motion features. The model has a
    single one-unit output.

    Args:
        n_frame: Frames per clip, passed to every TSM layer. TSM reshapes the
            leading axis to ``(-1, n_frame, ...)``, so the batch size
            presumably must be a multiple of ``n_frame`` — confirm with the
            data loader.
        nb_filters1: Filters in the stage-1 conv layers.
        nb_filters2: Filters in the stage-2 conv layers.
        input_shape: Shape (without batch) of both inputs.
        kernel_size: Conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after the hidden dense layer.
        pool_size: Average-pooling window.
        nb_dense: Units in the hidden dense layer.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and one output.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Motion stream, stage 1 (temporal shift before every conv)
    d1 = TSM_Cov2D(diff_input, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')
    d1_plus = TSM_Cov2D(d1, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')
    d2 = TSM_Cov2D(d1_plus, n_frame, nb_filters1, kernel_size, padding='valid', activation='tanh')
    d2_plus = TSM_Cov2D(d2, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')

    # Appearance stream, stage 1
    r1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r1_plus = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(r1)
    r2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(r1_plus)
    r2_plus = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(r2)

    # Attention mask 1 from the deepest stage-1 appearance features
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2_plus)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2_plus, g1])

    d3 = AveragePooling2D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)

    # NOTE(review): stage-2 appearance pooling reads r2, not r2_plus, so
    # r2_plus only feeds the attention mask. The motion stream continues
    # from d2_plus. Confirm whether this asymmetry is intentional.
    r3 = AveragePooling2D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)

    # Motion stream, stage 2
    d5 = TSM_Cov2D(d4, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d5_plus = TSM_Cov2D(d5, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d6 = TSM_Cov2D(d5_plus, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')
    d6_plus = TSM_Cov2D(d6, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')

    # Appearance stream, stage 2
    r5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r5_plus = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r5)
    r6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(r5_plus)
    r6_plus = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r6)

    # Attention mask 2
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6_plus)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6_plus, g2])

    d7 = AveragePooling2D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)

    # Regression head
    d9 = Flatten()(d8)
    d10 = Dense(nb_dense, activation='tanh')(d9)
    d11 = Dropout(dropout_rate2)(d10)
    out = Dense(1)(d11)
    model = Model(inputs=[diff_input, rawf_input], outputs=out)
    return model




#%% TSM Multi-tasking

def MTTS_CAN(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25,
                           dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128):
    """Build a multi-task TS-CAN with pulse and respiration heads.

    The motion stream uses ``TSM_Cov2D`` (temporal shift + conv) on
    ``diff_input``. The appearance stream uses plain ``Conv2D`` on
    ``rawf_input`` and supplies two attention masks. The flattened gated
    motion features feed two independent one-unit heads.

    Args:
        n_frame: Frames per clip, passed to every TSM layer.
        nb_filters1: Filters in the stage-1 conv layers.
        nb_filters2: Filters in the stage-2 conv layers.
        input_shape: Shape (without batch) of both inputs.
        kernel_size: Conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after each head's hidden dense layer.
        pool_size: Average-pooling window.
        nb_dense: Units in each head's hidden dense layer.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and outputs
        ``[pulse, resp]``.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Motion stream, stage 1
    d1 = TSM_Cov2D(diff_input, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')
    d2 = TSM_Cov2D(d1, n_frame, nb_filters1, kernel_size, padding='valid', activation='tanh')

    # Appearance stream, stage 1
    r1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(r1)

    # Attention mask 1
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    d3 = AveragePooling2D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)

    r3 = AveragePooling2D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)

    # Motion stream, stage 2
    d5 = TSM_Cov2D(d4, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d6 = TSM_Cov2D(d5, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')

    # Appearance stream, stage 2
    r5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(r5)

    # Attention mask 2
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])

    d7 = AveragePooling2D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)

    # Shared flattened features for both heads
    d9 = Flatten()(d8)

    # Pulse head
    d10_y = Dense(nb_dense, activation='tanh')(d9)
    d11_y = Dropout(dropout_rate2)(d10_y)
    out_y = Dense(1, name='pulse')(d11_y)

    # Respiration head (explicitly named layers)
    d10_r = Dense(nb_dense, activation='tanh', name='dense_resp')(d9)
    d11_r = Dropout(dropout_rate2, name='dropout_resp')(d10_r)
    out_r = Dense(1, name='resp')(d11_r)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_r])
    return model



def NEW_MODEL(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3, 3), dropout_rate1=0.25, dropout_rate2=0.5,
            pool_size=(2, 2, 2), nb_dense=128, use_dataloader=False):
    """Build a per-frame two-stream attention CNN followed by an LSTM.

    Work-in-progress model. The model is built in three steps:

    1. A 2D two-stream CNN encodes one ``(diff, raw)`` frame pair into a flat
       feature vector. Its motion stream is gated by attention masks from
       the appearance stream.
    2. The two video inputs are stacked on the channel axis, and each time
       step is encoded by that CNN via ``TimeDistributed``.
    3. ``LSTM(256)`` summarises the sequence, followed by a dense head.

    The previous body could not build: it referenced undefined ``out`` and
    ``model``, applied ``AveragePooling3D`` to 4-D frame tensors, misused
    ``Concatenate``, and split channels with ``0:2`` / ``3:5``.

    Args:
        n_frame: Frames per clip (time axis of both video inputs) and size of
            the output vector. The previous body hard-coded 20 frames.
        nb_filters1: Filters in the first conv pair of each stream. The
            previous body silently overrode this with 32.
        nb_filters2: Filters in the second conv pair of each stream. The
            previous body silently overrode this with 64.
        input_shape: Accepted for interface compatibility; not used. Frames
            are fixed at 36x36x3, as in the previous body.
        kernel_size: Accepted for compatibility; not used. The per-frame CNN
            is 2D and uses (3, 3), as the previous body did.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after the hidden dense layer of the head.
        pool_size: Accepted for compatibility; not used. (2, 2) is used, since
            3D pooling cannot apply to single frames.
        nb_dense: Units in the hidden dense layer of the head.
        use_dataloader: Accepted but not used.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input_video, rawf_input_video]``
        (each ``(n_frame, 36, 36, 3)`` without batch) and one ``Dense(n_frame)``
        output. NOTE(review): the output size mirrors ``DeepPhy_3DCNN``'s
        per-frame output; confirm the intended target.
    """
    frame_shape = (36, 36, 3)
    n_channels = frame_shape[-1]
    kernel_2d = (3, 3)
    pool_2d = (2, 2)

    # ---- Per-frame two-stream attention CNN ----
    diff_input = Input(shape=frame_shape)
    rawf_input = Input(shape=frame_shape)

    # Motion branch, stage 1
    d1 = Conv2D(nb_filters1, kernel_2d, padding='same', activation='tanh')(diff_input)
    d2 = Conv2D(nb_filters1, kernel_2d, activation='tanh')(d1)

    # Appearance branch, stage 1, and its attention mask
    r1 = Conv2D(nb_filters1, kernel_2d, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_2d, activation='tanh')(r1)
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    # Motion branch, stage 2 (single frames are 4-D tensors, so pool in 2D)
    d3 = AveragePooling2D(pool_2d)(gated1)
    d4 = Dropout(dropout_rate1)(d3)
    d5 = Conv2D(nb_filters2, kernel_2d, padding='same', activation='tanh')(d4)
    d6 = Conv2D(nb_filters2, kernel_2d, activation='tanh')(d5)

    # Appearance branch, stage 2, and its attention mask
    r3 = AveragePooling2D(pool_2d)(r2)
    r4 = Dropout(dropout_rate1)(r3)
    r5 = Conv2D(nb_filters2, kernel_2d, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_2d, activation='tanh')(r5)
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])

    d7 = AveragePooling2D(pool_2d)(gated2)
    d8 = Dropout(dropout_rate1)(d7)
    frame_features = Flatten()(d8)
    cnn = Model(inputs=[diff_input, rawf_input], outputs=frame_features)
    cnn.trainable = True

    # ---- Single-input wrapper so TimeDistributed receives one tensor ----
    # Channels [0, n_channels) are the difference frame; the rest are the raw frame.
    stacked_frame = Input(shape=frame_shape[:-1] + (2 * n_channels,))
    diff_part = Lambda(lambda t: t[..., :n_channels])(stacked_frame)
    rawf_part = Lambda(lambda t: t[..., n_channels:])(stacked_frame)
    frame_encoder = Model(inputs=stacked_frame, outputs=cnn([diff_part, rawf_part]))

    # ---- Sequence model over the encoded frames ----
    diff_input_video = Input(shape=(n_frame,) + frame_shape)
    rawf_input_video = Input(shape=(n_frame,) + frame_shape)
    stacked_video = Concatenate(axis=-1)([diff_input_video, rawf_input_video])
    encoded_frames = TimeDistributed(frame_encoder)(stacked_video)
    encoded_sequence = LSTM(256)(encoded_frames)

    d10 = Dense(nb_dense, activation='tanh')(encoded_sequence)
    d11 = Dropout(dropout_rate2)(d10)
    out = Dense(n_frame)(d11)

    model = Model(inputs=[diff_input_video, rawf_input_video], outputs=out)
    return model

#input_shape = (36, 36, 10, 3)
#model = DeepPhy_3DCNN(10, 32, 64, input_shape)
#print('==========================')



#%% TSM Peak Detection + Motion Compensation
def TS_CAN_PEAKDETECTION(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25,
                           dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128):
    """Build a TS-CAN that outputs a pulse vector plus a second vector derived from it.

    The trunk matches ``MTTS_CAN``. The ``pulse`` head is a ``Dense(n_frame)``
    layer. The second output (layer name ``'resp'``) is a linear
    ``Dense(n_frame)`` applied directly to the pulse output. The comments
    below call it "Peak"; NOTE(review): confirm whether it should be named
    for peaks rather than respiration.

    Args:
        n_frame: Frames per clip, passed to every TSM layer, and the size of
            both outputs.
        nb_filters1: Filters in the stage-1 conv layers.
        nb_filters2: Filters in the stage-2 conv layers.
        input_shape: Shape (without batch) of both inputs.
        kernel_size: Conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after the hidden dense layer.
        pool_size: Average-pooling window.
        nb_dense: Units in the hidden dense layer.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and outputs
        ``[pulse, resp]``.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Pulse:
    d1 = TSM_Cov2D(diff_input, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')
    d2 = TSM_Cov2D(d1, n_frame, nb_filters1, kernel_size, padding='valid', activation='tanh')

    # Appearance:
    r1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(r1)

    # Attention Pulse:
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    # Pulse:
    d3 = AveragePooling2D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)

    # Appearance:
    r3 = AveragePooling2D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)

    # Pulse:
    d5 = TSM_Cov2D(d4, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d6 = TSM_Cov2D(d5, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')

    # Appearance:
    r5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(r5)

    # Attention Pulse:
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])

    # Pulse:
    d7 = AveragePooling2D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)
    d9 = Flatten()(d8)

    # Pulse head: one value per frame
    d10_y = Dense(nb_dense, activation='tanh')(d9)
    d11_y = Dropout(dropout_rate2)(d10_y)
    out_y = Dense(n_frame, name='pulse')(d11_y)

    # Peak: linear map from the pulse prediction (layer is named 'resp')
    #d12_y = Dense(nb_dense, activation='tanh')(out_y)
    out_y2 = Dense(n_frame, name='resp')(out_y)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_y2])
    return model


#%% TSM Peak Detection + Motion Compensation
def TS_CAN_PEAKDETECTION_MOTION(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25,
                           dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128):
    """Build a TS-CAN with parallel pulse and motion streams and a fused output.

    Both the pulse and motion streams apply ``TSM_Cov2D`` to ``diff_input``.
    Each stream is gated by its own attention masks computed from a shared
    appearance stream on ``rawf_input``. Each stream ends in a one-unit
    head. The two head outputs are concatenated and fused by a further
    one-unit dense layer.

    The previous body could not build: ``concatenate`` was undefined,
    ``out_y1``/``out_y2``/``out_y3`` were undefined, and both heads were
    named ``'pulse'``. NOTE(review): the motion head is now named
    ``'motion'`` and the fused output ``'pulse_combined'``; confirm these
    names and the intended three targets.

    Args:
        n_frame: Frames per clip, passed to every TSM layer.
        nb_filters1: Filters in the stage-1 conv layers.
        nb_filters2: Filters in the stage-2 conv layers.
        input_shape: Shape (without batch) of both inputs.
        kernel_size: Conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after each head's hidden dense layer.
        pool_size: Average-pooling window.
        nb_dense: Units in each head's hidden dense layer.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and outputs
        ``[pulse, motion, pulse_combined]``.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Pulse:
    d1 = TSM_Cov2D(diff_input, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')
    d2 = TSM_Cov2D(d1, n_frame, nb_filters1, kernel_size, padding='valid', activation='tanh')

    # Motion:
    m1 = TSM_Cov2D(diff_input, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')
    m2 = TSM_Cov2D(m1, n_frame, nb_filters1, kernel_size, padding='valid', activation='tanh')

    # Appearance:
    r1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(r1)

    # Attention Pulse:
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    # Attention Motion:
    gm1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    gm1 = Attention_mask()(gm1)
    gated1_motion = multiply([m2, gm1])

    # Pulse:
    d3 = AveragePooling2D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)

    # Motion:
    m3 = AveragePooling2D(pool_size)(gated1_motion)
    m4 = Dropout(dropout_rate1)(m3)

    # Appearance:
    r3 = AveragePooling2D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)

    # Pulse:
    d5 = TSM_Cov2D(d4, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d6 = TSM_Cov2D(d5, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')

    # Motion:
    m5 = TSM_Cov2D(m4, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    m6 = TSM_Cov2D(m5, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')

    # Appearance:
    r5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(r5)

    # Attention Pulse:
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])

    # Attention Motion:
    gm2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    gm2 = Attention_mask()(gm2)
    gated2_motion = multiply([m6, gm2])

    # Pulse:
    d7 = AveragePooling2D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)
    d9 = Flatten()(d8)

    # Motion:
    m7 = AveragePooling2D(pool_size)(gated2_motion)
    m8 = Dropout(dropout_rate1)(m7)
    m9 = Flatten()(m8)

    # Pulse head:
    d10_y = Dense(nb_dense, activation='tanh')(d9)
    d11_y = Dropout(dropout_rate2)(d10_y)
    out_y = Dense(1, name='pulse')(d11_y)

    # Motion head (needs a name distinct from 'pulse'; Keras rejects duplicates):
    m10_y = Dense(nb_dense, activation='tanh')(m9)
    m11_y = Dropout(dropout_rate2)(m10_y)
    out_y_motion = Dense(1, name='motion')(m11_y)

    # Concatenate the motion and pulse outputs, then fuse them with a dense layer:
    combined_y = Concatenate()([out_y, out_y_motion])
    out_y_combined = Dense(1, name='pulse_combined')(combined_y)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_y_motion, out_y_combined])
    return model




# model = MTTS_CAN(20, 64, 128, (36, 36, 3))

#%%
def MTTS_CAN_Dual(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25,
                           dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128):
    """Build a multi-task TS-CAN with task-specific attention and stage-2 motion layers.

    Both tasks share the stage-1 motion layers (``d1``, ``d2``) and the whole
    appearance stream. Each task (pulse, and ``_rr`` for resp) has its own
    attention masks, its own stage-2 ``TSM_Cov2D`` motion layers, and its own
    dense head.

    Args:
        n_frame: Frames per clip, passed to every TSM layer.
        nb_filters1: Filters in the stage-1 conv layers.
        nb_filters2: Filters in the stage-2 conv layers.
        input_shape: Shape (without batch) of both inputs.
        kernel_size: Conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after each head's hidden dense layer.
        pool_size: Average-pooling window.
        nb_dense: Units in each head's hidden dense layer.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and outputs
        ``[pulse, resp]``.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Shared motion stream, stage 1
    d1 = TSM_Cov2D(diff_input, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')
    d2 = TSM_Cov2D(d1, n_frame, nb_filters1, kernel_size, padding='valid', activation='tanh')

    # Shared appearance stream, stage 1
    r1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(r1)

    # Pulse attention mask 1
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    # Resp attention mask 1 (applied to the same shared motion features)
    g1_rr = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1_rr = Attention_mask()(g1_rr)
    gated1_rr = multiply([d2, g1_rr])


    d3 = AveragePooling2D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)

    d3_rr = AveragePooling2D(pool_size)(gated1_rr)
    d4_rr = Dropout(dropout_rate1)(d3_rr)

    r3 = AveragePooling2D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)

    # Pulse motion stream, stage 2
    d5 = TSM_Cov2D(d4, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d6 = TSM_Cov2D(d5, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')

    # Resp motion stream, stage 2
    d5_rr = TSM_Cov2D(d4_rr, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d6_rr = TSM_Cov2D(d5_rr, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')


    # Shared appearance stream, stage 2
    r5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(r5)

    # Pulse attention mask 2
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])

    # Resp attention mask 2
    g2_rr = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2_rr = Attention_mask()(g2_rr)
    gated2_rr = multiply([d6_rr, g2_rr])

    d7 = AveragePooling2D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)
    d9 = Flatten()(d8)

    d7_rr = AveragePooling2D(pool_size)(gated2_rr)
    d8_rr = Dropout(dropout_rate1)(d7_rr)
    d9_rr = Flatten()(d8_rr)

    # Pulse head
    d10_y = Dense(nb_dense, activation='tanh')(d9)
    d11_y = Dropout(dropout_rate2)(d10_y)
    out_y = Dense(1, name='pulse')(d11_y)

    # Resp head
    d10_r = Dense(nb_dense, activation='tanh')(d9_rr)
    d11_r = Dropout(dropout_rate2)(d10_r)
    out_r = Dense(1, name='resp')(d11_r)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_r])
    return model


def MTTS_CAN_Dual_RNN(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25,
                           dropout_rate2=0.5, pool_size=(2, 2), nb_dense=128):
    """Build the ``MTTS_CAN_Dual`` architecture with named stage-2 gating layers.

    NOTE(review): despite the name, this builder contains no recurrent layer.
    The graph matches ``MTTS_CAN_Dual``, except that the stage-2 multiply
    layers are named ``"gated2"`` and ``"gated2_rr"``. Presumably this lets
    them be fetched by name (e.g. to attach an RNN later) — confirm.

    Args:
        n_frame: Frames per clip, passed to every TSM layer.
        nb_filters1: Filters in the stage-1 conv layers.
        nb_filters2: Filters in the stage-2 conv layers.
        input_shape: Shape (without batch) of both inputs.
        kernel_size: Conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after each head's hidden dense layer.
        pool_size: Average-pooling window.
        nb_dense: Units in each head's hidden dense layer.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and outputs
        ``[pulse, resp]``.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Shared motion stream, stage 1
    d1 = TSM_Cov2D(diff_input, n_frame, nb_filters1, kernel_size, padding='same', activation='tanh')
    d2 = TSM_Cov2D(d1, n_frame, nb_filters1, kernel_size, padding='valid', activation='tanh')

    # Shared appearance stream, stage 1
    r1 = Conv2D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size, activation='tanh')(r1)

    # Pulse attention mask 1
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    # Resp attention mask 1
    g1_rr = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1_rr = Attention_mask()(g1_rr)
    gated1_rr = multiply([d2, g1_rr])


    d3 = AveragePooling2D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)

    d3_rr = AveragePooling2D(pool_size)(gated1_rr)
    d4_rr = Dropout(dropout_rate1)(d3_rr)

    r3 = AveragePooling2D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)

    # Pulse motion stream, stage 2
    d5 = TSM_Cov2D(d4, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d6 = TSM_Cov2D(d5, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')

    # Resp motion stream, stage 2
    d5_rr = TSM_Cov2D(d4_rr, n_frame, nb_filters2, kernel_size, padding='same', activation='tanh')
    d6_rr = TSM_Cov2D(d5_rr, n_frame, nb_filters2, kernel_size, padding='valid', activation='tanh')


    # Shared appearance stream, stage 2
    r5 = Conv2D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size, activation='tanh')(r5)

    # Pulse attention mask 2 (named layer)
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2], name="gated2")

    # Resp attention mask 2 (named layer)
    g2_rr = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2_rr = Attention_mask()(g2_rr)
    gated2_rr = multiply([d6_rr, g2_rr], name="gated2_rr")

    d7 = AveragePooling2D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)
    d9 = Flatten()(d8)

    d7_rr = AveragePooling2D(pool_size)(gated2_rr)
    d8_rr = Dropout(dropout_rate1)(d7_rr)
    d9_rr = Flatten()(d8_rr)

    # Pulse head
    d10_y = Dense(nb_dense, activation='tanh')(d9)
    d11_y = Dropout(dropout_rate2)(d10_y)
    out_y = Dense(1, name='pulse')(d11_y)

    # Resp head
    d10_r = Dense(nb_dense, activation='tanh')(d9_rr)
    d11_r = Dropout(dropout_rate2)(d10_r)
    out_r = Dense(1, name='resp')(d11_r)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_r])
    return model

 #%% 3D-CAN

def DeepPhy_3DCNN(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3, 3), dropout_rate1=0.25, dropout_rate2=0.5,
            pool_size=(2, 2, 2), nb_dense=128, use_dataloader=False):
    """Build a 3D-CAN: two-stream ``Conv3D`` attention network with a per-frame output.

    The motion stream on ``diff_input`` is gated by ``Conv3D`` attention masks
    from the appearance stream on ``rawf_input``. ``Attention_mask``
    normalises over axes 1 and 2 only. The head is ``Dense(n_frame)``.

    Args:
        n_frame: Size of the output vector.
        nb_filters1: Filters in the first conv pair of each stream.
        nb_filters2: Filters in the second conv pair of each stream.
        input_shape: Shape (without batch) of both inputs; presumably
            (height, width, frames, channels) — confirm with the data loader.
        kernel_size: 3D conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after the hidden dense layer.
        pool_size: 3D average-pooling window.
        nb_dense: Units in the hidden dense layer.
        use_dataloader: Accepted but not used by this function.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and one
        ``Dense(n_frame)`` output.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Motion branch, stage 1
    d1 = Conv3D(nb_filters1, kernel_size, padding='same', activation='tanh')(diff_input)
    d2 = Conv3D(nb_filters1, kernel_size, activation='tanh')(d1)

    # Appearance Branch, stage 1, and attention mask 1
    r1 = Conv3D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv3D(nb_filters1, kernel_size, activation='tanh')(r1)
    g1 = Conv3D(1, (1, 1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])
    d3 = AveragePooling3D(pool_size)(gated1)

    # Motion branch, stage 2
    d4 = Dropout(dropout_rate1)(d3)
    d5 = Conv3D(nb_filters2, kernel_size, padding='same', activation='tanh')(d4)
    d6 = Conv3D(nb_filters2, kernel_size, activation='tanh')(d5)

    # Appearance branch, stage 2, and attention mask 2
    r3 = AveragePooling3D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)
    r5 = Conv3D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv3D(nb_filters2, kernel_size, activation='tanh')(r5)
    g2 = Conv3D(1, (1, 1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])
    d7 = AveragePooling3D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)
    # Regression head: one value per frame
    d9 = Flatten()(d8)
    d10 = Dense(nb_dense, activation='tanh')(d9)
    d11 = Dropout(dropout_rate2)(d10)
    out = Dense(n_frame)(d11)
    model = Model(inputs=[diff_input, rawf_input], outputs=out)
    return model

if __name__ == "__main__":
    # Smoke test: build a 3D-CAN for 10-frame 36x36 RGB clips. Guarded so that
    # importing this module no longer builds a model and prints to stdout.
    input_shape = (36, 36, 10, 3)
    model = DeepPhy_3DCNN(10, 32, 64, input_shape)
    print('==========================')


#%% MT-3DCAN

def DeepPhy_3DCNN_MT(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3, 3), dropout_rate1=0.25, dropout_rate2=0.5,
            pool_size=(2, 2, 2), nb_dense=128, use_dataloader=False):
    """Build a multi-task 3D-CAN with per-frame pulse and resp heads.

    The trunk matches ``DeepPhy_3DCNN``. The flattened gated motion features
    feed two independent heads, each a ``Dense(n_frame)``.

    Args:
        n_frame: Size of each output vector.
        nb_filters1: Filters in the first conv pair of each stream.
        nb_filters2: Filters in the second conv pair of each stream.
        input_shape: Shape (without batch) of both inputs; presumably
            (height, width, frames, channels) — confirm with the data loader.
        kernel_size: 3D conv kernel size.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after each head's hidden dense layer.
        pool_size: 3D average-pooling window.
        nb_dense: Units in each head's hidden dense layer.
        use_dataloader: Accepted but not used by this function.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and outputs
        ``[pulse, resp]``.
    """

    diff_input = Input(shape=input_shape)
    rawf_input = Input(shape=input_shape)

    # Motion branch, stage 1
    d1 = Conv3D(nb_filters1, kernel_size, padding='same', activation='tanh')(diff_input)
    d2 = Conv3D(nb_filters1, kernel_size, activation='tanh')(d1)

    # Appearance Branch, stage 1, and attention mask 1
    r1 = Conv3D(nb_filters1, kernel_size, padding='same', activation='tanh')(rawf_input)
    r2 = Conv3D(nb_filters1, kernel_size, activation='tanh')(r1)
    g1 = Conv3D(1, (1, 1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    gated1 = multiply([d2, g1])

    # Motion branch, stage 2
    d3 = AveragePooling3D(pool_size)(gated1)
    d4 = Dropout(dropout_rate1)(d3)
    d5 = Conv3D(nb_filters2, kernel_size, padding='same', activation='tanh')(d4)
    d6 = Conv3D(nb_filters2, kernel_size, activation='tanh')(d5)

    # Appearance branch, stage 2, and attention mask 2
    r3 = AveragePooling3D(pool_size)(r2)
    r4 = Dropout(dropout_rate1)(r3)
    r5 = Conv3D(nb_filters2, kernel_size, padding='same', activation='tanh')(r4)
    r6 = Conv3D(nb_filters2, kernel_size, activation='tanh')(r5)
    g2 = Conv3D(1, (1, 1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    gated2 = multiply([d6, g2])
    d7 = AveragePooling3D(pool_size)(gated2)
    d8 = Dropout(dropout_rate1)(d7)

    # Shared features -> pulse head
    d9 = Flatten()(d8)
    d10_y = Dense(nb_dense, activation='tanh')(d9)
    d11_y = Dropout(dropout_rate2)(d10_y)
    out_y = Dense(n_frame, name='pulse')(d11_y)

    # Resp head
    d10_r = Dense(nb_dense, activation='tanh')(d9)
    d11_r = Dropout(dropout_rate2)(d10_r)
    out_r = Dense(n_frame, name='resp')(d11_r)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_r])

    return model


#%% Hybrid-CAN

def Hybrid_CAN(n_frame, nb_filters1, nb_filters2, input_shape_1, input_shape_2, kernel_size_1=(3, 3, 3), kernel_size_2=(3,3),
                dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), pool_size_2=(2, 2),
                nb_dense=128, use_dataloader=False):
    """Build a Hybrid-CAN: 3D motion stream gated by 2D appearance attention.

    The motion stream applies ``Conv3D`` to ``diff_input``. The appearance
    stream applies ``Conv2D`` to ``rawf_input``. Its 2D attention masks get
    extra trailing axes (``K.expand_dims``, and ``K.repeat_elements`` for the
    second mask) so they broadcast against the 5-D motion features.

    Assumes ``input_shape_1`` is (height, width, frames, channels) and
    ``input_shape_2`` is (height, width, channels), with matching spatial
    sizes so the masks line up with the motion features — TODO confirm.

    Args:
        n_frame: Size of the output vector.
        nb_filters1: Filters in the first conv pair of each stream.
        nb_filters2: Filters in the second conv pair of each stream.
        input_shape_1: Shape (without batch) of ``diff_input``.
        input_shape_2: Shape (without batch) of ``rawf_input``.
        kernel_size_1: 3D kernel size for the motion stream.
        kernel_size_2: 2D kernel size for the appearance stream.
        dropout_rate1: Dropout after each pooling stage.
        dropout_rate2: Dropout after the hidden dense layer.
        pool_size_1: 3D pooling window for the motion stream.
        pool_size_2: 2D pooling window for the appearance stream.
        nb_dense: Units in the hidden dense layer.
        use_dataloader: Accepted but not used by this function.

    Returns:
        A Keras ``Model`` with inputs ``[diff_input, rawf_input]`` and one
        ``Dense(n_frame)`` output.
    """

    diff_input = Input(shape=input_shape_1)
    rawf_input = Input(shape=input_shape_2)

    # Motion branch
    d1 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh')(diff_input)
    d2 = Conv3D(nb_filters1, kernel_size_1, activation='tanh')(d1)

    # App branch
    r1 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size_2, activation='tanh')(r1)

    # Mask from App (g1) * Motion Branch (d2). g1's last axis has size 1;
    # expand_dims adds one more size-1 axis, so it broadcasts over d2's
    # last two axes in the multiply.
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    g1 = K.expand_dims(g1, axis=-1)
    gated1 = multiply([d2, g1])

    # Motion Branch
    d3 = AveragePooling3D(pool_size_1)(gated1)
    d4 = Dropout(dropout_rate1)(d3)
    d5 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh')(d4)
    d6 = Conv3D(nb_filters2, kernel_size_1, activation='tanh')(d5)

    # App branch
    r3 = AveragePooling2D(pool_size_2)(r2)
    r4 = Dropout(dropout_rate1)(r3)
    r5 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size_2, activation='tanh')(r5)

    # Mask from App (g2) * Motion Branch (d6). Unlike g1, the mask is first
    # tiled along its last axis to d6.shape[3] and only then expanded.
    # NOTE(review): broadcasting alone (as for g1) would presumably suffice;
    # confirm whether the repeat is needed.
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    g2 = K.repeat_elements(g2, d6.shape[3], axis=-1)
    g2 = K.expand_dims(g2, axis=-1)
    gated2 = multiply([d6, g2])

    # Motion Branch
    d7 = AveragePooling3D(pool_size_1)(gated2)
    d8 = Dropout(dropout_rate1)(d7)

    # Regression head on the motion features: one value per frame
    d9 = Flatten()(d8)
    d10 = Dense(nb_dense, activation='tanh')(d9)
    d11 = Dropout(dropout_rate2)(d10)
    out = Dense(n_frame)(d11)

    model = Model(inputs=[diff_input, rawf_input], outputs=out)
    return model

 #%% MT-Hybrid-CAN

def Hybrid_CAN_MT(n_frame, nb_filters1, nb_filters2, input_shape_1, input_shape_2, kernel_size_1=(3, 3, 3),
                kernel_size_2=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), pool_size_2=(2, 2),
                nb_dense=128, use_dataloader=False):
    """Build the multi-task hybrid CAN with two Dense heads ('pulse', 'resp').

    The attention-gated motion/appearance trunk is shared; both heads read the
    same flattened features and each ends in a linear Dense(n_frame) layer.

    Args:
        n_frame: units of each linear output layer.
        nb_filters1, nb_filters2: filters of the first / second conv stage.
        input_shape_1: motion input shape (consumed by Conv3D).
        input_shape_2: appearance input shape (consumed by Conv2D).
        kernel_size_1, kernel_size_2: Conv3D / Conv2D kernel sizes.
        dropout_rate1: dropout after each pooling step.
        dropout_rate2: dropout inside each head.
        pool_size_1, pool_size_2: motion / appearance pooling windows.
        nb_dense: hidden units of each head.
        use_dataloader: accepted for signature compatibility; unused.

    Returns:
        keras Model with inputs [motion, appearance] and outputs [pulse, resp].
    """

    def _dense_head(features, out_name):
        # Hidden tanh layer -> dropout -> linear output named `out_name`.
        hidden = Dense(nb_dense, activation='tanh')(features)
        hidden = Dropout(dropout_rate2)(hidden)
        return Dense(n_frame, name=out_name)(hidden)

    # Layer creation order is kept identical to the original implementation
    # because Keras auto-names layers by creation order.
    motion_in = Input(shape=input_shape_1)
    appear_in = Input(shape=input_shape_2)

    motion = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh')(motion_in)
    motion = Conv3D(nb_filters1, kernel_size_1, activation='tanh')(motion)

    appear_stage1 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(appear_in)
    appear_stage1 = Conv2D(nb_filters1, kernel_size_2, activation='tanh')(appear_stage1)

    # First appearance-derived mask, broadcast over the motion features.
    first_mask = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(appear_stage1)
    first_mask = Attention_mask()(first_mask)
    first_mask = K.expand_dims(first_mask, axis=-1)
    motion = multiply([motion, first_mask])

    motion = AveragePooling3D(pool_size_1)(motion)
    motion = Dropout(dropout_rate1)(motion)
    motion = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh')(motion)
    motion = Conv3D(nb_filters2, kernel_size_1, activation='tanh')(motion)

    appear = AveragePooling2D(pool_size_2)(appear_stage1)
    appear = Dropout(dropout_rate1)(appear)
    appear = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(appear)
    appear = Conv2D(nb_filters2, kernel_size_2, activation='tanh')(appear)

    # Second mask: repeated along motion axis 3, then a trailing axis added.
    second_mask = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(appear)
    second_mask = Attention_mask()(second_mask)
    second_mask = K.repeat_elements(second_mask, motion.shape[3], axis=-1)
    second_mask = K.expand_dims(second_mask, axis=-1)
    motion = multiply([motion, second_mask])

    motion = AveragePooling3D(pool_size_1)(motion)
    motion = Dropout(dropout_rate1)(motion)
    shared = Flatten()(motion)

    heads = [_dense_head(shared, label) for label in ('pulse', 'resp')]
    return Model(inputs=[motion_in, appear_in], outputs=heads)

def Hybrid_CAN_MT_Dual(n_frame, nb_filters1, nb_filters2, input_shape_1, input_shape_2, kernel_size_1=(3, 3, 3),
                kernel_size_2=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), pool_size_2=(2, 2),
                nb_dense=128, use_dataloader=False):
    """Build the hybrid CAN with two independently-gated heads ('dysub', 'drsub').

    The first conv stage and first attention mask are shared. After the second
    conv stage, each output gets its own appearance-derived mask, its own
    pooling/flatten path and its own Dense head.

    Args:
        n_frame: units of each linear output layer.
        nb_filters1, nb_filters2: filters of the first / second conv stage.
        input_shape_1: motion input shape (consumed by Conv3D).
        input_shape_2: appearance input shape (consumed by Conv2D).
        kernel_size_1, kernel_size_2: Conv3D / Conv2D kernel sizes.
        dropout_rate1: dropout after each pooling step.
        dropout_rate2: dropout inside each head.
        pool_size_1, pool_size_2: motion / appearance pooling windows.
        nb_dense: hidden units of each head.
        use_dataloader: accepted for signature compatibility; unused.

    Returns:
        keras Model with inputs [motion, appearance] and outputs [dysub, drsub].
    """

    def _gate_with_new_mask(appear_feats, motion_feats):
        # Fresh 1x1 sigmoid mask -> Attention_mask -> repeat along motion axis 3
        # -> trailing axis, then gate the motion features with it.
        mask = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(appear_feats)
        mask = Attention_mask()(mask)
        mask = K.repeat_elements(mask, motion_feats.shape[3], axis=-1)
        mask = K.expand_dims(mask, axis=-1)
        return multiply([motion_feats, mask])

    # Layers are created in exactly the original order (Keras auto-naming).
    motion_in = Input(shape=input_shape_1)
    appear_in = Input(shape=input_shape_2)

    motion = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh')(motion_in)
    motion = Conv3D(nb_filters1, kernel_size_1, activation='tanh')(motion)

    appear_stage1 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(appear_in)
    appear_stage1 = Conv2D(nb_filters1, kernel_size_2, activation='tanh')(appear_stage1)

    # Shared first-stage mask.
    first_mask = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(appear_stage1)
    first_mask = Attention_mask()(first_mask)
    first_mask = K.expand_dims(first_mask, axis=-1)
    motion = multiply([motion, first_mask])

    motion = AveragePooling3D(pool_size_1)(motion)
    motion = Dropout(dropout_rate1)(motion)
    motion = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh')(motion)
    motion = Conv3D(nb_filters2, kernel_size_1, activation='tanh')(motion)

    appear = AveragePooling2D(pool_size_2)(appear_stage1)
    appear = Dropout(dropout_rate1)(appear)
    appear = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(appear)
    appear = Conv2D(nb_filters2, kernel_size_2, activation='tanh')(appear)

    # One gated stream per output: first 'dysub', then 'drsub'.
    gated_streams = [_gate_with_new_mask(appear, motion), _gate_with_new_mask(appear, motion)]

    pooled_streams = []
    for stream in gated_streams:
        stream = AveragePooling3D(pool_size_1)(stream)
        pooled_streams.append(Dropout(dropout_rate1)(stream))

    flat_streams = [Flatten()(stream) for stream in pooled_streams]

    predictions = []
    for features, label in zip(flat_streams, ('dysub', 'drsub')):
        hidden = Dense(nb_dense, activation='tanh')(features)
        hidden = Dropout(dropout_rate2)(hidden)
        predictions.append(Dense(n_frame, name=label)(hidden))

    return Model(inputs=[motion_in, appear_in], outputs=predictions)

# input_shape_1 = (36, 36, 10, 3) # Motion
# input_shape_2 = (36, 36, 3) # Appearance
#
# model = DeepPhysMix(32, 64, input_shape_1, input_shape_2)

def Hybrid_CAN_MT_Dual_RNN(n_frame, nb_filters1, nb_filters2, input_shape_1, input_shape_2, kernel_size_1=(3, 3, 3),
                kernel_size_2=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), pool_size_2=(2, 2),
                nb_dense=128, use_second_derivative=False, use_raw_signal=False, target_signals=("dysub", "drsub")):
    """Build a hybrid CAN with a shared first stage and one GRU decoder per target signal.

    The first attention-gated stage (mask ``attention_mask_1_shared``) is shared.
    Each name in ``target_signals`` gets its own second-stage mask, a global
    spatial average pool, two stacked bidirectional GRUs with a residual sum,
    and a single-unit linear GRU that emits the waveform.

    Args:
        n_frame: unused; kept for signature compatibility with the other builders.
        nb_filters1, nb_filters2: filters of the first / second conv stage.
        input_shape_1: motion input shape (consumed by Conv3D; the sequence axis
            fed to the GRUs is axis 3 of the conv features -- presumably time).
        input_shape_2: appearance input shape (consumed by Conv2D).
        kernel_size_1, kernel_size_2: Conv3D / Conv2D kernel sizes.
        dropout_rate1: dropout after pooling and inside every GRU.
        dropout_rate2: unused.
        pool_size_1, pool_size_2: motion / appearance pooling windows.
        nb_dense: units of each GRU direction in the encoders.
        use_second_derivative: also emit ``f"{sig}_SD"`` (one step shorter).
        use_raw_signal: also emit ``f"{sig}_raw"``.
        target_signals: output names; iterated once. A tuple default avoids the
            shared-mutable-default pitfall (any iterable of str is accepted).

    Returns:
        keras Model with inputs [motion, appearance]; for every signal the
        outputs are, in order, ``sig``, optional ``f"{sig}_raw"``, optional
        ``f"{sig}_SD"``.
    """
    diff_input = Input(shape=input_shape_1)
    rawf_input = Input(shape=input_shape_2)

    # Motion branch ('same' padding keeps H/W aligned with the appearance masks)
    d1 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh', name='d1')(diff_input)
    d2 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh', name='d2')(d1)

    # Appearance branch
    r1 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(r1)

    outputs = []
    # The first stage is shared by all targets; `sig` only tags its layer names.
    sig = "shared"
    # Mask from appearance (g1) * motion branch (d2)
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask(name=f"attention_mask_1_{sig}")(g1)
    g1 = K.expand_dims(g1, axis=-1)
    gated1 = multiply([d2, g1], name=f'gated1_{sig}')

    # Motion branch, stage 2
    d3 = AveragePooling3D(pool_size_1)(gated1)
    d4 = Dropout(dropout_rate1)(d3)
    d5 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh', name=f'd5_{sig}')(d4)
    d6 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh', name=f'd6_{sig}')(d5)

    # Appearance branch, stage 2
    r3 = AveragePooling2D(pool_size_2)(r2)
    r4 = Dropout(dropout_rate1)(r3)
    r5 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r5)

    for sig in target_signals:
        # Per-target attention mask, repeated along motion axis 3, then applied.
        g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
        g2 = Attention_mask(name=f"attention_mask_2_{sig}")(g2)
        g2 = K.repeat_elements(g2, d6.shape[3], axis=-1)
        g2 = K.expand_dims(g2, axis=-1)
        gated2 = multiply([d6, g2], name=f'gated2_{sig}')

        d7 = AveragePooling3D((3, 3, 1), name=f'd7_{sig}')(gated2)
        d8 = Dropout(dropout_rate1)(d7)
        # Average over the whole remaining H x W so both spatial axes become 1.
        d9 = AveragePooling3D((d8.shape[1], d8.shape[2], 1), name=f'd9_{sig}')(d8)

        # Drop the two size-1 spatial axes -> (batch, steps, channels) for the GRUs.
        d9 = Lambda(lambda x: x[:, 0, 0, :, :])(d9)
        # NOTE(review): both encoders name their inner GRU f'enc_{sig}'; the
        # Bidirectional wrappers get distinct auto names, so the model builds.
        enc_pulse = Bidirectional(GRU(nb_dense, return_sequences=True, dropout=dropout_rate1, name=f'enc_{sig}'))(d9)
        enc_pulse_2 = Bidirectional(GRU(nb_dense, return_sequences=True, dropout=dropout_rate1, name=f'enc_{sig}'))(enc_pulse)
        # Residual connection around the second encoder.
        enc_pulse_2_sum = enc_pulse + enc_pulse_2
        out_y = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=sig)(enc_pulse_2_sum)
        outputs.append(out_y)

        if use_raw_signal:
            # Decoded from the first encoder only.
            out_raw = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=f'{sig}_raw')(enc_pulse)
            outputs.append(out_raw)

        if use_second_derivative:
            # Drop the last step so the output is one step shorter than `sig`.
            out_SD = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=f'{sig}_SD')(enc_pulse[:, :-1, :])
            outputs.append(out_SD)

    model = Model(inputs=[diff_input, rawf_input], outputs=outputs)
    return model


def Hybrid_CAN_MT_Dual_RNN_v2(n_frame, nb_filters1, nb_filters2, input_shape_1, input_shape_2, kernel_size_1=(3, 3, 3),
                kernel_size_2=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), pool_size_2=(2, 2),
                nb_dense=128, use_dataloader=False):
    """Build the dual-output hybrid CAN with bidirectional-GRU decoders.

    Both branches use 'same' padding so the appearance masks line up with the
    motion feature maps. The motion features are gated by a shared first mask,
    then by two separate second-stage masks (one for 'dysub', one for 'drsub').
    Each gated stream is averaged over its remaining H x W extent and decoded
    by a bidirectional GRU followed by a single-unit linear GRU.

    Intermediate shapes are reported through ``logging`` at DEBUG level
    (they used to be printed unconditionally to stdout).

    Args:
        n_frame, dropout_rate2, use_dataloader: accepted for signature
            compatibility with the other builders; unused here.
        nb_filters1, nb_filters2: filters of the first / second conv stage.
        input_shape_1: motion input shape (Conv3D; the masks are inserted at
            axis 3, presumably the time axis -- confirm against the loader).
        input_shape_2: appearance input shape (Conv2D).
        kernel_size_1, kernel_size_2: Conv3D / Conv2D kernel sizes.
        dropout_rate1: dropout after pooling and inside the GRUs.
        pool_size_1, pool_size_2: motion / appearance pooling windows.
        nb_dense: units per direction of each bidirectional GRU.

    Returns:
        keras Model with inputs named 'motion' and 'appearance' and outputs
        ['dysub', 'drsub'].
    """
    import logging  # local: the module import header is outside this block
    logger = logging.getLogger(__name__)

    diff_input = Input(shape=input_shape_1, name='motion')
    rawf_input = Input(shape=input_shape_2, name='appearance')

    # Motion branch
    d1 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh', name='d1')(diff_input)
    d2 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh', name='d2')(d1)

    # Appearance branch
    r1 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(r1)

    # Mask from appearance (g1) * motion branch (d2); the inserted axis 3
    # lets the 2D mask broadcast across motion axis 3.
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask()(g1)
    logger.debug("g1 initial %s", g1.shape)
    g1 = K.expand_dims(g1, axis=3)
    logger.debug("d2 shape %s", d2.shape)
    logger.debug("g1 shape %s", g1.shape)
    gated1 = multiply([d2, g1], name='gated1')
    logger.debug("gated1 shape %s", gated1.shape)

    # Motion branch, stage 2
    d3 = AveragePooling3D(pool_size_1)(gated1)
    d4 = Dropout(dropout_rate1)(d3)
    d5 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh', name='d5')(d4)
    d6 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh', name='d6')(d5)

    # Appearance branch, stage 2
    r3 = AveragePooling2D(pool_size_2)(r2)
    r4 = Dropout(dropout_rate1)(r3)
    r5 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r5)

    # Second-stage mask for the 'dysub' stream
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2 = Attention_mask()(g2)
    g2 = K.expand_dims(g2, axis=3)
    logger.debug("d6 shape %s", d6.shape)
    logger.debug("g2 shape %s", g2.shape)
    gated2 = multiply([d6, g2], name='gated2')

    # Second-stage mask for the 'drsub' stream
    g2_rr = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2_rr = Attention_mask()(g2_rr)
    g2_rr = K.expand_dims(g2_rr, axis=3)
    gated2_rr = multiply([d6, g2_rr], name='gated2_rr')

    # Spatial pooling of both streams
    d7 = AveragePooling3D((3, 3, 1), name='d7')(gated2)
    d8 = Dropout(dropout_rate1)(d7)
    d7_rr = AveragePooling3D((3, 3, 1), name='d7_rr')(gated2_rr)
    d8_rr = Dropout(dropout_rate1)(d7_rr)

    # Average over the whole remaining H x W so both spatial axes become 1
    d9 = AveragePooling3D((d8.shape[1], d8.shape[2], 1), name='d9')(d8)
    d9_rr = AveragePooling3D((d8_rr.shape[1], d8_rr.shape[2], 1), name='d9_rr')(d8_rr)

    # Drop the size-1 spatial axes -> (batch, steps, channels) for the GRUs
    d9 = Lambda(lambda x: x[:, 0, 0, :, :])(d9)
    d9_rr = Lambda(lambda x: x[:, 0, 0, :, :])(d9_rr)
    logger.debug("d9 shape %s", d9.shape)

    d10_y = Bidirectional(GRU(nb_dense, dropout=dropout_rate1, return_sequences=True))(d9)
    out_y = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name='dysub')(d10_y)

    d10_r = Bidirectional(GRU(nb_dense, dropout=dropout_rate1, return_sequences=True))(d9_rr)
    out_r = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name='drsub')(d10_r)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_r])
    return model


def Hybrid_CAN_MT_Dual_RNN_v3(n_frame, nb_filters1, nb_filters2, input_shape_1, input_shape_2, kernel_size_1=(3, 3, 3),
                kernel_size_2=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), pool_size_2=(2, 2),
                nb_dense=128, use_dataloader=False):
    """Build the dual-output GRU hybrid CAN that also exposes its attention masks.

    Same topology as the v2 builder, but the first-stage mask
    ('combined_mask') and the 'dysub' second-stage mask ('skin_mask') are
    returned as additional model outputs.

    The final global pool now uses the actual remaining H x W extent instead of
    a hard-coded (6, 6, 1) window. For 36x36 inputs with the default pooling
    this is the same window; for other sizes the old window did not collapse
    the spatial axes, and ``x[:, 0, 0]`` silently kept only one patch.

    Intermediate shapes are reported through ``logging`` at DEBUG level.

    Args:
        n_frame, dropout_rate2, use_dataloader: accepted for signature
            compatibility; unused here.
        nb_filters1, nb_filters2: filters of the first / second conv stage.
        input_shape_1: motion input shape (Conv3D).
        input_shape_2: appearance input shape (Conv2D).
        kernel_size_1, kernel_size_2: Conv3D / Conv2D kernel sizes.
        dropout_rate1: dropout after pooling and inside the GRUs.
        pool_size_1, pool_size_2: motion / appearance pooling windows.
        nb_dense: units per direction of each bidirectional GRU.

    Returns:
        keras Model with inputs 'motion', 'appearance' and outputs
        [dysub, drsub, combined_mask, skin_mask].
    """
    import logging  # local: the module import header is outside this block
    logger = logging.getLogger(__name__)

    diff_input = Input(shape=input_shape_1, name='motion')
    rawf_input = Input(shape=input_shape_2, name='appearance')

    # Motion branch
    d1 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh', name='d1')(diff_input)
    d2 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh', name='d2')(d1)

    # Appearance branch
    r1 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(r1)

    # Shared first-stage mask (also a model output)
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    combined_mask = Attention_mask(name='combined_mask')(g1)
    logger.debug("g1 initial %s", combined_mask.shape)
    g1 = K.expand_dims(combined_mask, axis=3)
    logger.debug("d2 shape %s", d2.shape)
    logger.debug("g1 shape %s", g1.shape)
    gated1 = multiply([d2, g1], name='gated1')
    logger.debug("gated1 shape %s", gated1.shape)

    # Motion branch, stage 2
    d3 = AveragePooling3D(pool_size_1)(gated1)
    d4 = Dropout(dropout_rate1)(d3)
    d5 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh', name='d5')(d4)
    d6 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh', name='d6')(d5)

    # Appearance branch, stage 2
    r3 = AveragePooling2D(pool_size_2)(r2)
    r4 = Dropout(dropout_rate1)(r3)
    r5 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r5)

    # Second-stage mask for 'dysub' (also a model output)
    g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    skin_mask = Attention_mask(name='skin_mask')(g2)
    g2 = K.expand_dims(skin_mask, axis=3)
    logger.debug("d6 shape %s", d6.shape)
    logger.debug("g2 shape %s", g2.shape)
    gated2 = multiply([d6, g2], name='gated2')

    # Second-stage mask for 'drsub'
    g2_rr = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
    g2_rr = Attention_mask()(g2_rr)
    g2_rr = K.expand_dims(g2_rr, axis=3)
    gated2_rr = multiply([d6, g2_rr], name='gated2_rr')

    # Spatial pooling of both streams
    d7 = AveragePooling3D((3, 3, 1), name='d7')(gated2)
    d8 = Dropout(dropout_rate1)(d7)
    d7_rr = AveragePooling3D((3, 3, 1), name='d7_rr')(gated2_rr)
    d8_rr = Dropout(dropout_rate1)(d7_rr)

    # Average over the whole remaining H x W (derived, not hard-coded) so both
    # spatial axes become exactly 1 for any input size.
    d9 = AveragePooling3D((d8.shape[1], d8.shape[2], 1), name='d9')(d8)
    d9_rr = AveragePooling3D((d8_rr.shape[1], d8_rr.shape[2], 1), name='d9_rr')(d8_rr)

    # Drop the size-1 spatial axes -> (batch, steps, channels) for the GRUs
    d9 = Lambda(lambda x: x[:, 0, 0, :, :])(d9)
    d9_rr = Lambda(lambda x: x[:, 0, 0, :, :])(d9_rr)
    logger.debug("d9 shape %s", d9.shape)

    d10_y = Bidirectional(GRU(nb_dense, dropout=dropout_rate1, return_sequences=True))(d9)
    out_y = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name='dysub')(d10_y)

    d10_r = Bidirectional(GRU(nb_dense, dropout=dropout_rate1, return_sequences=True))(d9_rr)
    out_r = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name='drsub')(d10_r)

    model = Model(inputs=[diff_input, rawf_input], outputs=[out_y, out_r, combined_mask, skin_mask])
    return model

from tensorflow.keras.layers import ConvLSTM2D
def Hybrid_CAN_MT_Dual_RNN_v4(n_frame, nb_filters1, nb_filters2, input_shape_1, input_shape_2, kernel_size_1=(3, 3, 3),
                kernel_size_2=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), pool_size_2=(2, 2),
                nb_dense=128, use_second_derivative=False, use_raw_signal=False, target_signals=("dysub", "drsub")):
    """Build a ConvLSTM2D-based hybrid CAN with one GRU decoder per target signal.

    The motion branch uses ConvLSTM2D, so its features are laid out as
    (batch, time, H, W, channels). The appearance masks are (batch, H, W, 1)
    and receive a new axis at position 1, so they broadcast over the time
    axis. The previous ``axis=0`` lined the batch axis up with the time axis:
    masks were applied to the wrong sample/frame when batch == time, and the
    multiply failed otherwise.

    The first motion pooling now uses ``pool_size_2`` for H/W (default (2, 2),
    i.e. the former hard-coded (1, 2, 2)), so it shrinks the motion maps by the
    same factor as the appearance branch and the second mask stays aligned.

    Intermediate shapes are reported through ``logging`` at DEBUG level.

    Args:
        n_frame, dropout_rate2, pool_size_1: unused; kept for signature compatibility.
        nb_filters1, nb_filters2: filters of the first / second stage.
        input_shape_1: motion input shape, (time, H, W, channels) as required by ConvLSTM2D.
        input_shape_2: appearance input shape (Conv2D).
        kernel_size_1: only its first two entries are used, as the ConvLSTM2D kernel.
        kernel_size_2: Conv2D kernel size.
        dropout_rate1: dropout in pooling stages, ConvLSTM2D cells and GRUs.
        pool_size_2: appearance (and motion H/W) pooling window; int or pair.
        nb_dense: units per direction of each bidirectional GRU encoder.
        use_second_derivative: also emit ``f"{sig}_SD"`` (one step shorter).
        use_raw_signal: also emit ``f"{sig}_raw"``.
        target_signals: output names; any iterable of str (tuple default
            avoids a shared mutable default).

    Returns:
        keras Model with inputs [motion, appearance]; per signal the outputs
        are ``sig`` (last axis squeezed), optional ``f"{sig}_raw"``, optional
        ``f"{sig}_SD"``.
    """
    import logging  # local: the module import header is outside this block
    logger = logging.getLogger(__name__)

    # H/W pooling factor shared by both branches so the masks line up.
    if isinstance(pool_size_2, int):
        spatial_pool = (pool_size_2, pool_size_2)
    else:
        spatial_pool = tuple(pool_size_2)

    diff_input = Input(shape=input_shape_1)
    rawf_input = Input(shape=input_shape_2)

    # Motion branch (ConvLSTM2D: (batch, time, H, W, C) in and out)
    d1 = ConvLSTM2D(nb_filters1, kernel_size_1[:-1], padding='same', activation='tanh', dropout=dropout_rate1, return_sequences=True, name='d1')(diff_input)
    d2 = ConvLSTM2D(nb_filters1, kernel_size_1[:-1], padding='same', activation='tanh', dropout=dropout_rate1, return_sequences=True, name='d2')(d1)

    # Appearance branch
    r1 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(rawf_input)
    r2 = Conv2D(nb_filters1, kernel_size_2, padding='same', activation='tanh')(r1)

    outputs = []
    # The first stage is shared by all targets; `sig` only tags its layer names.
    sig = "shared"
    g1 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r2)
    g1 = Attention_mask(name=f"attention_mask_1_{sig}")(g1)
    # (batch, H, W, 1) -> (batch, 1, H, W, 1): broadcast over time, not batch.
    g1 = K.expand_dims(g1, axis=1)
    gated1 = multiply([d2, g1], name=f'gated1_{sig}')

    # Motion branch, stage 2 (pool H/W only; time is kept)
    logger.debug("gated1 shape %s", gated1.shape)
    d3 = AveragePooling3D((1,) + spatial_pool, data_format="channels_last")(gated1)
    logger.debug("d3 shape %s", d3.shape)
    d4 = Dropout(dropout_rate1)(d3)
    d5 = ConvLSTM2D(nb_filters2, kernel_size_1[:-1], padding='same', activation='tanh', dropout=dropout_rate1, return_sequences=True, name=f'd5_{sig}')(d4)
    d6 = ConvLSTM2D(nb_filters2, kernel_size_1[:-1], padding='same', activation='tanh', dropout=dropout_rate1, return_sequences=True, name=f'd6_{sig}')(d5)
    logger.debug("d6 shape: %s", d6.shape)

    # Appearance branch, stage 2
    r3 = AveragePooling2D(pool_size_2)(r2)
    r4 = Dropout(dropout_rate1)(r3)
    r5 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r4)
    r6 = Conv2D(nb_filters2, kernel_size_2, padding='same', activation='tanh')(r5)

    for sig in target_signals:
        # Per-target attention mask, broadcast over the time axis (axis 1).
        g2 = Conv2D(1, (1, 1), padding='same', activation='sigmoid')(r6)
        g2 = Attention_mask(name=f"attention_mask_2_{sig}")(g2)
        g2 = K.expand_dims(g2, axis=1)
        gated2 = multiply([d6, g2], name=f'gated2_{sig}')

        d7 = AveragePooling3D((1, 3, 3), name=f'd7_{sig}')(gated2)
        d8 = Dropout(dropout_rate1)(d7)
        # Average over the whole remaining H x W so both spatial axes become 1.
        d9 = AveragePooling3D((1, d8.shape[2], d8.shape[3]), name=f'd9_{sig}')(d8)

        # Drop the size-1 spatial axes -> (batch, time, channels) for the GRUs.
        d9 = Lambda(lambda x: x[:, :, 0, 0, :])(d9)
        enc_pulse = Bidirectional(GRU(nb_dense, return_sequences=True, dropout=dropout_rate1, name=f'enc_{sig}'))(d9)
        out_y = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=f"{sig}_FD")(enc_pulse)
        # Squeeze the single-unit feature axis; this layer carries the output name.
        out_y = Lambda(lambda x: x[:, :, 0], name=sig)(out_y)
        outputs.append(out_y)

        if use_raw_signal:
            out_raw = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=f'{sig}_raw')(enc_pulse)
            outputs.append(out_raw)

        if use_second_derivative:
            # Drop the last step so the output is one step shorter than `sig`.
            out_SD = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=f'{sig}_SD')(enc_pulse[:, :-1, :])
            outputs.append(out_SD)

    model = Model(inputs=[diff_input, rawf_input], outputs=outputs)
    return model

def Hybrid_CAN_MT_Dual_RNN_MD(n_frame, nb_filters1, nb_filters2, input_shape_1, input_shape_2, input_shape_3, kernel_size_1=(3, 3, 3),
                kernel_size_2=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), pool_size_2=(2, 2, 1),
                nb_dense=128,
                use_second_derivative_frames=False,
                use_second_derivative_frames_only=False,
                predict_raw_signal=False,
                predict_first_derivative=True,
                predict_second_derivative=False,
                target_signals=("dysub", "drsub")):
    """Build a multi-derivative hybrid CAN with GRU decoders per target signal.

    Three inputs are taken: frames for the attention branch (zeroth
    derivative), motion frames (first derivative) and optional acceleration
    frames (second derivative). The attention branch produces masks that gate
    the motion branch (and the accel branch, when enabled). Each target signal
    gets its own second-stage mask and up to three GRU decoders.

    Args:
        n_frame, kernel_size_2, dropout_rate2, pool_size_1: unused; kept for
            signature compatibility.
        nb_filters1, nb_filters2: filters of the first / second Conv3D stage.
        input_shape_1: attention-branch input shape (Conv3D).
        input_shape_2: motion input shape (Conv3D).
        input_shape_3: acceleration input shape (Conv3D). The masks are sliced
            with ``[:, :, :, :-1, :]`` before gating it, so presumably it has
            one frame fewer on axis 3 -- confirm against the loader.
        kernel_size_1: Conv3D kernel size.
        dropout_rate1: dropout after pooling and inside the GRUs.
        pool_size_2: AveragePooling3D window of all branches.
        nb_dense: units per direction of each bidirectional GRU encoder.
        use_second_derivative_frames: concatenate accel features with motion
            features before decoding.
        use_second_derivative_frames_only: decode from accel features alone.
            This now builds the accel branch on its own; previously it raised
            NameError unless ``use_second_derivative_frames`` was also True.
        predict_raw_signal: emit ``f"{sig}_raw"``.
        predict_first_derivative: emit ``sig``.
        predict_second_derivative: emit ``f"{sig}_SD"`` (one step shorter).
        target_signals: output-name prefixes; any iterable of str.

    Returns:
        keras Model with inputs [zeroth, first, second derivative] and, per
        signal, the enabled outputs in the order raw, first, second derivative.
    """
    zeroth_deriv_input = Input(shape=input_shape_1)
    first_deriv_input = Input(shape=input_shape_2)
    second_deriv_input = Input(shape=input_shape_3)

    # The accel branch is needed whenever its features will be consumed.
    use_accel_branch = use_second_derivative_frames or use_second_derivative_frames_only

    # Attention-mask branch
    m1 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh')(zeroth_deriv_input)
    m2 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh')(m1)

    # Motion branch
    d1 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh', name='d1')(first_deriv_input)
    d2 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh', name='d2')(d1)

    outputs = []
    # The first stage is shared by all targets; `sig` only tags its layer names.
    sig = "shared"
    # Mask from the attention branch (g1) * motion branch (d2)
    g1 = Conv3D(1, (1, 1, 1), padding='same', activation='sigmoid')(m2)
    g1 = Attention_mask(name=f"attention_mask_1_{sig}")(g1)
    gated1_motion = multiply([d2, g1], name=f'gated1_motion_{sig}')

    # Attention branch, stage 2 (fed by the un-gated m2)
    m3 = AveragePooling3D(pool_size_2)(m2)
    m4 = Dropout(dropout_rate1)(m3)
    m5 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh')(m4)
    m6 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh')(m5)

    # Motion branch, stage 2
    d3 = AveragePooling3D(pool_size_2)(gated1_motion)
    d4 = Dropout(dropout_rate1)(d3)
    d5 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh', name=f'd5_{sig}')(d4)
    d6 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh', name=f'd6_{sig}')(d5)

    if use_accel_branch:
        # Accel branch; the mask drops its last step on axis 3 to match a2.
        a1 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh')(second_deriv_input)
        a2 = Conv3D(nb_filters1, kernel_size_1, padding='same', activation='tanh')(a1)
        gated1_accel = multiply([a2, g1[:, :, :, :-1, :]], name=f'gated1_accel_{sig}')
        a3 = AveragePooling3D(pool_size_2)(gated1_accel)
        a4 = Dropout(dropout_rate1)(a3)
        a5 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh')(a4)
        a6 = Conv3D(nb_filters2, kernel_size_1, padding='same', activation='tanh')(a5)

    for sig in target_signals:
        # Per-target second-stage mask
        g2 = Conv3D(1, (1, 1, 1), padding='same', activation='sigmoid')(m6)
        g2 = Attention_mask(name=f"attention_mask_2_{sig}")(g2)
        gated2_motion = multiply([d6, g2], name=f'gated2_motion_{sig}')

        # Average motion features over the whole H x W, then drop those axes
        d9 = AveragePooling3D((gated2_motion.shape[1], gated2_motion.shape[2], 1), name=f'd9_{sig}')(gated2_motion)
        d9 = Lambda(lambda x: x[:, 0, 0, :, :])(d9)

        if use_accel_branch:
            gated2_accel = multiply([a6, g2[:, :, :, :-1, :]], name=f'gated2_accel_{sig}')
            a9 = AveragePooling3D((gated2_accel.shape[1], gated2_accel.shape[2], 1), name=f'a9_{sig}')(gated2_accel)
            a9 = Lambda(lambda x: x[:, 0, 0, :, :])(a9)
            # Pad one step at the end so accel features match the motion length.
            a9 = tf.keras.layers.ZeroPadding1D(padding=(0, 1))(a9)

        # Choose the decoder input: accel only, motion + accel, or motion only.
        if use_second_derivative_frames_only:
            concat = a9
        elif use_second_derivative_frames:
            concat = tf.keras.layers.Concatenate(axis=-1)([d9, a9])
        else:
            concat = d9

        if predict_raw_signal:
            enc_pulse_appear = Bidirectional(GRU(nb_dense, return_sequences=True, dropout=dropout_rate1, name=f'enc_appear_{sig}'))(concat)
            out_y_appear = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=f'{sig}_raw')(enc_pulse_appear)
            outputs.append(out_y_appear)

        if predict_first_derivative:
            enc_pulse_motion = Bidirectional(GRU(nb_dense, return_sequences=True, dropout=dropout_rate1, name=f'enc_motion_{sig}'))(concat)
            out_y_motion = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=sig)(enc_pulse_motion)
            outputs.append(out_y_motion)

        if predict_second_derivative:
            # Drop the last step so this output is one step shorter.
            enc_pulse_accel = Bidirectional(GRU(nb_dense, return_sequences=True, dropout=dropout_rate1, name=f'enc_accel_{sig}'))(concat[:, :-1, :])
            out_y_accel = GRU(1, activation=None, dropout=dropout_rate1, return_sequences=True, name=f'{sig}_SD')(enc_pulse_accel)
            outputs.append(out_y_accel)

    model = Model(inputs=[zeroth_deriv_input, first_deriv_input, second_deriv_input], outputs=outputs)
    return model

#%%
def DeepPhysMotion(n_frame, nb_filters1, nb_filters2, input_shape_1, kernel_size_1=(3, 3, 3),
                dropout_rate1=0.25, dropout_rate2=0.5, pool_size_1=(2, 2, 2), nb_dense=128):
    """Build a motion-only 3D-CNN regressor (no appearance branch, no attention).

    Two identical conv stages, each consisting of a 'same' Conv3D, an unpadded
    Conv3D, average pooling and dropout, are followed by a Dense head.

    Args:
        n_frame: units of the final linear Dense layer.
        nb_filters1, nb_filters2: filters of the first / second stage.
        input_shape_1: motion input shape (consumed by Conv3D).
        kernel_size_1: Conv3D kernel size.
        dropout_rate1: dropout after each pooling step.
        dropout_rate2: dropout before the output layer.
        pool_size_1: AveragePooling3D window.
        nb_dense: hidden units of the head.

    Returns:
        keras Model with a single input and a single output.
    """
    motion_in = Input(shape=input_shape_1)

    # Layer creation order matches the original (Keras auto-naming).
    features = motion_in
    for stage_filters in (nb_filters1, nb_filters2):
        features = Conv3D(stage_filters, kernel_size_1, padding='same', activation='tanh')(features)
        features = Conv3D(stage_filters, kernel_size_1, activation='tanh')(features)
        features = AveragePooling3D(pool_size_1)(features)
        features = Dropout(dropout_rate1)(features)

    head = Flatten()(features)
    head = Dense(nb_dense, activation='tanh')(head)
    head = Dropout(dropout_rate2)(head)
    prediction = Dense(n_frame)(head)

    return Model(inputs=[motion_in], outputs=prediction)



#%% TSM without Attention and App branch

def DeepPhys_TSM_Motion(n_frame, nb_filters1, nb_filters2, input_shape, kernel_size=(3, 3), dropout_rate1=0.25, dropout_rate2=0.5,
            pool_size=(2, 2), nb_dense=128):
    """Build a motion-only TSM network (no attention, no appearance features).

    Two stages, each consisting of a 'same' TSM conv, a 'valid' TSM conv,
    average pooling and dropout, feed a Dense head with a single linear output.

    The model declares a second Input of the same shape that is not connected
    to any layer, presumably so it accepts the same two-input batches as the
    attention variants -- confirm against the data loader.

    Args:
        n_frame: forwarded to every TSM_Cov2D call.
        nb_filters1, nb_filters2: filters of the first / second stage.
        input_shape: shape of both declared inputs.
        kernel_size: TSM_Cov2D kernel size.
        dropout_rate1: dropout after each pooling step.
        dropout_rate2: dropout before the output layer.
        pool_size: AveragePooling2D window.
        nb_dense: hidden units of the head.

    Returns:
        keras Model with two inputs (the second unused) and one output of width 1.
    """
    motion_in = Input(shape=input_shape)
    unused_appearance_in = Input(shape=input_shape)

    # Same layer creation order as the original (Keras auto-naming).
    features = motion_in
    for stage_filters in (nb_filters1, nb_filters2):
        features = TSM_Cov2D(features, n_frame, stage_filters, kernel_size, padding='same', activation='tanh')
        features = TSM_Cov2D(features, n_frame, stage_filters, kernel_size, padding='valid', activation='tanh')
        features = AveragePooling2D(pool_size)(features)
        features = Dropout(dropout_rate1)(features)

    head = Flatten()(features)
    head = Dense(nb_dense, activation='tanh')(head)
    head = Dropout(dropout_rate2)(head)
    prediction = Dense(1)(head)
    return Model(inputs=[motion_in, unused_appearance_in], outputs=prediction)
#%%
class HeartBeat(keras.callbacks.Callback):
    """Keras callback that prints a fixed progress marker after every epoch.

    The constructor stores the train/test generators, run arguments, CV split
    id and output directory. They are only referenced by the disabled
    prediction dump that is kept commented out in ``on_epoch_end``.
    """

    def __init__(self, train_gen, test_gen, args, cv_split, save_dir):
        super().__init__()
        self.train_gen = train_gen
        self.test_gen = test_gen
        self.args = args
        self.cv_split = cv_split
        self.save_dir = save_dir

    def on_epoch_end(self, epoch, logs=None):
        """Print the progress marker.

        ``logs`` defaults to None instead of a shared mutable ``{}``; neither
        ``epoch`` nor ``logs`` is currently used.
        """
        # Disabled: dump train/test predictions to .mat files after epoch 45.
        # if epoch > 45:
        #     print(' | Predicting HeartBeat')
        #     yptrain = self.model.predict(self.train_gen, verbose=1)
        #     scipy.io.savemat(self.save_dir + '/yptrain' + str(epoch) + '_cv' + self.cv_split +'.mat',
        #                      mdict={'yptrain': yptrain})
        #     yptest = self.model.predict(self.test_gen, verbose=1)
        #     scipy.io.savemat(self.save_dir + '/yptest' + str(epoch) + '_cv' + self.cv_split + '.mat',
        #                      mdict={'yptest': yptest})
        print('PROGRESS: 0.00%')
