from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Activation, Flatten, Dropout, BatchNormalization, Conv1D
from tensorflow.keras.layers import Input, MaxPooling1D, Multiply, LSTM

def model_Conv1D(dim, win_len, num_classes, num_feat_map=128, p=0.3):
    """Assemble an uncompiled 1D-convolutional classifier.

    The network is a Keras ``Sequential`` stack:
    Conv1D -> MaxPooling1D -> BatchNormalization -> Dropout -> Conv1D ->
    BatchNormalization -> Dropout -> Flatten -> Dense(300, relu) ->
    BatchNormalization -> Dropout -> Dense(num_classes) -> softmax.

    Args:
        dim: Size of the last axis of each input sample.
        win_len: Size of the first (window/time) axis of each input sample.
        num_classes: Number of units in the final ``logits`` layer.
        num_feat_map: Number of filters in the first convolution.
        p: Dropout rate shared by the three dropout layers.

    Returns:
        The ``Sequential`` model; its last layer is the softmax activation
        named ``probs``. The model is not compiled here.
    """

    def _layer_sequence():
        # Layers are yielded lazily, so each one is constructed immediately
        # before it is appended to the model.
        yield Conv1D(num_feat_map, kernel_size=3, activation='relu', padding='same',
                     input_shape=(win_len, dim), name='Conv_1')
        yield MaxPooling1D(pool_size=4, name='Max_pool_1')
        yield BatchNormalization(name='Bn_1')
        yield Dropout(p, name='Drop_1')
        yield Conv1D(16, kernel_size=3, activation='relu', padding='same', name='Conv_2')
        yield BatchNormalization(name='Bn_2')
        yield Dropout(p, name='Drop_2')
        yield Flatten(name='flatten')
        # This dense layer is intentionally left unnamed (Keras auto-names it).
        yield Dense(300, activation='relu')
        yield BatchNormalization(name='Bn_3')
        yield Dropout(p, name='Drop_3')
        yield Dense(num_classes, name='logits')
        yield Activation('softmax', name='probs')

    classifier = Sequential()
    for layer in _layer_sequence():
        classifier.add(layer)
    return classifier

def model_Conv1D_concepts(dim, win_len, num_classes, n_concepts, num_feat_map=128, p=0.3):
    """Build an uncompiled 1D-CNN with a concept layer feeding the classifier.

    A convolutional feature extractor (Conv1D -> MaxPooling1D ->
    BatchNormalization -> Dropout -> Conv1D -> BatchNormalization -> Dropout
    -> Flatten) is followed by a sigmoid concept layer. The class logits are
    computed from the concept activations only.

    Args:
        dim: Size of the last axis of each input sample.
        win_len: Size of the first (window/time) axis of each input sample.
        num_classes: Number of units in the ``logits`` layer.
        n_concepts: Number of units in the ``concept_logits`` layer.
        num_feat_map: Number of filters in the first convolution. Defaults to
            128, the value that used to be hard-coded in ``Conv_1`` (the
            parameter was previously accepted but ignored).
        p: Dropout rate shared by both dropout layers.

    Returns:
        A functional ``Model`` named ``Video_concepts`` with two outputs, in
        order: the sigmoid concept activations (``c_probs``) and the softmax
        class probabilities (``probs``).
    """
    inputs = Input(shape=(win_len, dim), name='Input_1')

    # Feature extractor; the first convolution now honours ``num_feat_map``,
    # matching the sibling Conv1D builders in this module.
    x = Conv1D(num_feat_map, kernel_size=3, activation='relu', padding='same', name='Conv_1')(inputs)
    x = MaxPooling1D(pool_size=4, name='Max_pool_1')(x)
    x = BatchNormalization(name='Bn_1')(x)
    x = Dropout(p, name='Drop_1')(x)
    x = Conv1D(16, kernel_size=3, activation='relu', padding='same', name='Conv_2')(x)
    x = BatchNormalization(name='Bn_2')(x)
    x = Dropout(p, name='Drop_2')(x)
    x = Flatten(name='flatten')(x)

    # Concept head: one independent sigmoid unit per concept.
    concepts = Dense(n_concepts, name='concept_logits')(x)
    concepts = Activation('sigmoid', name='c_probs')(concepts)

    # Class head sees only the concept activations.
    out = Dense(num_classes, name='logits')(concepts)
    out = Activation('softmax', name='probs')(out)

    model = Model(inputs=inputs, outputs=[concepts, out], name="Video_concepts")
    return model

def model_Conv1D_attn_concepts(dim, win_len, num_classes, n_concepts, num_feat_map=128, p=0.3):
    """Build an uncompiled 1D-CNN concept model with attention over concepts.

    After the convolutional feature extractor and the sigmoid concept layer,
    a tanh dense layer followed by a softmax produces one score per concept.
    The scores are multiplied element-wise with the concept activations, and
    the product feeds the class logits.

    Args:
        dim: Size of the last axis of each input sample.
        win_len: Size of the first (window/time) axis of each input sample.
        num_classes: Number of units in the ``logits`` layer.
        n_concepts: Number of units in the concept and attention layers.
        num_feat_map: Number of filters in the first convolution.
        p: Dropout rate shared by both dropout layers.

    Returns:
        A functional ``Model`` named ``Video_concepts`` with two outputs, in
        order: the sigmoid concept activations (``c_probs``) and the softmax
        class probabilities (``probs``).
    """
    model_input = Input(shape=(win_len, dim), name='Input_1')

    # --- convolutional feature extractor ---
    conv_1 = Conv1D(
        num_feat_map, kernel_size=3, activation='relu', padding='same', name='Conv_1'
    )(model_input)
    pooled = MaxPooling1D(pool_size=4, name='Max_pool_1')(conv_1)
    block_1 = Dropout(p, name='Drop_1')(BatchNormalization(name='Bn_1')(pooled))
    conv_2 = Conv1D(
        16, kernel_size=3, activation='relu', padding='same', name='Conv_2'
    )(block_1)
    block_2 = Dropout(p, name='Drop_2')(BatchNormalization(name='Bn_2')(conv_2))
    features = Flatten(name='flatten')(block_2)

    # --- concept head ---
    concept_logits = Dense(n_concepts, name='concept_logits')(features)
    concept_probs = Activation('sigmoid', name='c_probs')(concept_logits)

    # --- attention over concepts: tanh projection, then softmax scores ---
    attn_projection = Dense(
        n_concepts, name='attention_weights', activation='tanh'
    )(concept_probs)
    attn_scores = Activation('softmax', name='attn_score')(attn_projection)

    # --- class head on the attention-weighted concepts ---
    weighted_concepts = Multiply(name='mul')([attn_scores, concept_probs])
    class_logits = Dense(num_classes, name='logits')(weighted_concepts)
    class_probs = Activation('softmax', name='probs')(class_logits)

    return Model(
        inputs=model_input,
        outputs=[concept_probs, class_probs],
        name="Video_concepts",
    )

def model_LSTM(dim, win_len, num_classes, num_hidden_lstm=128, p=0.3):
    """Build an uncompiled two-layer LSTM classifier.

    Both LSTM layers return full sequences; the second one's output is
    flattened and passed through Dense(300, relu) -> BatchNormalization ->
    Dropout -> Dense(num_classes) -> softmax.

    Args:
        dim: Size of the last axis of each input sample.
        win_len: Size of the first (window/time) axis of each input sample.
        num_classes: Number of units in the ``logits`` layer.
        num_hidden_lstm: Number of units in the first LSTM layer.
        p: Dropout rate shared by the three dropout layers.

    Returns:
        A functional ``Model`` with a single output, the softmax class
        probabilities (``probs``). The model is named ``Video_concepts``
        even though it has no concept output.
    """
    model_input = Input(shape=(win_len, dim), name='Input_1')

    # --- recurrent feature extractor ---
    seq_1 = LSTM(num_hidden_lstm, return_sequences=True, name='lstm_1')(model_input)
    pooled = MaxPooling1D(pool_size=4, name='Max_pool_1')(seq_1)
    block_1 = Dropout(p, name='Drop_1')(BatchNormalization(name='Bn_1')(pooled))
    # Layer name 'lstm_s' is kept verbatim; it is part of the model's
    # runtime identity (e.g. for any name-based lookups).
    seq_2 = LSTM(16, return_sequences=True, name='lstm_s')(block_1)
    block_2 = Dropout(p, name='Drop_2')(BatchNormalization(name='Bn_2')(seq_2))
    features = Flatten(name='flatten')(block_2)

    # --- dense classification head ---
    hidden = Dense(300, activation='relu')(features)
    block_3 = Dropout(p, name='Drop_3')(BatchNormalization(name='Bn_3')(hidden))
    class_logits = Dense(num_classes, name='logits')(block_3)
    class_probs = Activation('softmax', name='probs')(class_logits)

    return Model(inputs=model_input, outputs=class_probs, name="Video_concepts")

def model_LSTM_concepts(dim, win_len, num_classes, n_concepts, num_hidden_lstm=128, p=0.3):
    """Build an uncompiled two-layer LSTM with a concept layer before the classifier.

    Both LSTM layers return full sequences; the flattened output of the
    second feeds a sigmoid concept layer, and the class logits are computed
    from the concept activations only.

    Args:
        dim: Size of the last axis of each input sample.
        win_len: Size of the first (window/time) axis of each input sample.
        num_classes: Number of units in the ``logits`` layer.
        n_concepts: Number of units in the ``concept_logits`` layer.
        num_hidden_lstm: Number of units in the first LSTM layer.
        p: Dropout rate shared by both dropout layers.

    Returns:
        A functional ``Model`` named ``Video_concepts`` with two outputs, in
        order: the sigmoid concept activations (``c_probs``) and the softmax
        class probabilities (``probs``).
    """
    model_input = Input(shape=(win_len, dim), name='Input_1')

    # --- recurrent feature extractor ---
    seq_1 = LSTM(num_hidden_lstm, return_sequences=True, name='lstm_1')(model_input)
    pooled = MaxPooling1D(pool_size=4, name='Max_pool_1')(seq_1)
    block_1 = Dropout(p, name='Drop_1')(BatchNormalization(name='Bn_1')(pooled))
    seq_2 = LSTM(16, return_sequences=True, name='lstm_s')(block_1)
    block_2 = Dropout(p, name='Drop_2')(BatchNormalization(name='Bn_2')(seq_2))
    features = Flatten(name='flatten')(block_2)

    # --- concept head ---
    concept_logits = Dense(n_concepts, name='concept_logits')(features)
    concept_probs = Activation('sigmoid', name='c_probs')(concept_logits)

    # --- class head, fed only by the concept activations ---
    class_logits = Dense(num_classes, name='logits')(concept_probs)
    class_probs = Activation('softmax', name='probs')(class_logits)

    return Model(
        inputs=model_input,
        outputs=[concept_probs, class_probs],
        name="Video_concepts",
    )


def model_LSTM_attn_concepts(dim, win_len, num_classes, n_concepts, num_hidden_lstm=128, p=0.3):
    """Build an uncompiled two-layer LSTM concept model with attention over concepts.

    The flattened output of the recurrent stack feeds a sigmoid concept
    layer. A tanh dense layer followed by a softmax then yields one score per
    concept; scores and concept activations are multiplied element-wise and
    the product feeds the class logits.

    Args:
        dim: Size of the last axis of each input sample.
        win_len: Size of the first (window/time) axis of each input sample.
        num_classes: Number of units in the ``logits`` layer.
        n_concepts: Number of units in the concept and attention layers.
        num_hidden_lstm: Number of units in the first LSTM layer.
        p: Dropout rate shared by both dropout layers.

    Returns:
        A functional ``Model`` named ``Video_concepts`` with two outputs, in
        order: the sigmoid concept activations (``c_probs``) and the softmax
        class probabilities (``probs``).
    """
    model_input = Input(shape=(win_len, dim), name='Input_1')

    # --- recurrent feature extractor ---
    seq_1 = LSTM(num_hidden_lstm, return_sequences=True, name='Lstm_1')(model_input)
    pooled = MaxPooling1D(pool_size=4, name='Max_pool_1')(seq_1)
    block_1 = Dropout(p, name='Drop_1')(BatchNormalization(name='Bn_1')(pooled))
    seq_2 = LSTM(16, return_sequences=True, name='Lstm_2')(block_1)
    block_2 = Dropout(p, name='Drop_2')(BatchNormalization(name='Bn_2')(seq_2))
    features = Flatten(name='flatten')(block_2)

    # --- concept head ---
    concept_logits = Dense(n_concepts, name='concept_logits')(features)
    concept_probs = Activation('sigmoid', name='c_probs')(concept_logits)

    # --- attention over concepts: tanh projection, then softmax scores ---
    attn_projection = Dense(
        n_concepts, name='attention_weights', activation='tanh'
    )(concept_probs)
    attn_scores = Activation('softmax', name='attn_score')(attn_projection)

    # --- class head on the attention-weighted concepts ---
    weighted_concepts = Multiply(name='mul')([attn_scores, concept_probs])
    class_logits = Dense(num_classes, name='logits')(weighted_concepts)
    class_probs = Activation('softmax', name='probs')(class_logits)

    return Model(
        inputs=model_input,
        outputs=[concept_probs, class_probs],
        name="Video_concepts",
    )