import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.models import Model
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity
import matplotlib.pyplot as plt

def train_model(model, X_train, y_train, X_val, y_val, model_dir, t, name, batch_size=32, epochs=50):
    """Compile and fit a single-output Keras classifier, saving checkpoints.

    The model is compiled with categorical cross-entropy, Adam and an
    ``accuracy`` metric, then trained with two callbacks:

    * a ``ModelCheckpoint`` that keeps only the epoch with the highest
      ``val_accuracy`` in ``<model_dir>/best_<name>_<t>.h5``;
    * a ``TensorBoard`` logger writing to ``logs/<name>_<t>``.

    After training, the final-epoch model is saved to
    ``<model_dir>/final_<name>_<t>.h5``.

    Args:
        model: An uncompiled (or recompilable) Keras model.
        X_train, y_train: Training inputs and targets. ``y_train`` is
            presumably one-hot encoded, as categorical cross-entropy expects.
        X_val, y_val: Validation inputs and targets.
        model_dir: Directory in which the ``.h5`` files are written.
        t: Run tag (e.g. a timestamp) appended to every file/log name.
        name: Run name used as the prefix of every file/log name.
        batch_size: Mini-batch size passed to ``model.fit``.
        epochs: Number of training epochs.

    Returns:
        ``(model, history)``: the trained model (final epoch, not the best
        checkpoint) and the Keras ``History`` object returned by ``fit``.
    """
    model.compile(loss=tf.keras.losses.categorical_crossentropy,
                  optimizer='adam',
                  metrics=['accuracy'])

    # Keep only the weights of the epoch with the best validation accuracy;
    # 'val_accuracy' exists because 'accuracy' is requested in compile().
    chk_path = os.path.join(model_dir, 'best_{}_{}.h5'.format(name, t))
    checkpoint = ModelCheckpoint(chk_path, monitor='val_accuracy', verbose=1,
                                 save_best_only=True, mode='max')
    tensorboard = TensorBoard(log_dir="logs/{}_{}".format(name, t))
    callbacks_list = [checkpoint, tensorboard]

    history = model.fit(X_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        verbose=1,
                        shuffle=True,
                        validation_data=(X_val, y_val),
                        callbacks=callbacks_list)

    # Save the final-epoch model under the same ``name`` prefix as the best
    # checkpoint (the ``name`` argument, not a module-level constant).
    model.save(os.path.join(model_dir, 'final_{}_{}.h5'.format(name, t)))

    return model, history

def train_concept_model(model, X_train, y_train, c_train, X_val, y_val, c_val, 
                         model_dir, t, n_concepts, name, batch_size=32, epochs=50):
    """Compile and fit a two-output Keras model (concept head + class head).

    The model must expose outputs named ``c_probs`` (trained with binary
    cross-entropy against ``c_train``) and ``probs`` (trained with
    categorical cross-entropy against ``y_train``). Training uses Adam and an
    ``accuracy`` metric on every output, with two callbacks:

    * a ``ModelCheckpoint`` keeping the epoch with the highest
      ``val_probs_accuracy`` in ``<model_dir>/best_<name>_<n_concepts>_<t>.h5``;
    * a ``TensorBoard`` logger writing to ``logs/<name>_<t>``.

    The final-epoch model is saved to
    ``<model_dir>/final_<name>_<n_concepts>_<t>.h5``.

    Args:
        model: Keras model with ``c_probs`` and ``probs`` outputs.
        X_train, X_val: Training / validation inputs.
        y_train, y_val: Class targets for the ``probs`` head.
        c_train, c_val: Concept targets for the ``c_probs`` head.
        model_dir: Directory for the ``.h5`` files.
        t: Run tag appended to every file/log name.
        n_concepts: Used only to tag the saved file names.
        name: Run name used as the file/log name prefix.
        batch_size: Mini-batch size for ``model.fit``.
        epochs: Number of training epochs.

    Returns:
        ``(model, history)``: the final-epoch model and the ``History`` object.
    """
    model.compile(
        loss={
            "c_probs": tf.keras.losses.binary_crossentropy,
            "probs": tf.keras.losses.categorical_crossentropy,
        },
        optimizer='adam',
        metrics=['accuracy'],
    )

    # Shared suffix for both the best-checkpoint and final-model files.
    run_tag = f'{name}_{n_concepts}_{t}'

    # Select the best epoch by the class head's validation accuracy.
    best_ckpt = ModelCheckpoint(
        os.path.join(model_dir, f'best_{run_tag}.h5'),
        monitor='val_probs_accuracy',
        verbose=1,
        save_best_only=True,
        mode='max',
    )
    tb_logger = TensorBoard(log_dir=f"logs/{name}_{t}")

    train_targets = {'probs': y_train, 'c_probs': c_train}
    val_targets = {'probs': y_val, 'c_probs': c_val}

    history = model.fit(
        X_train,
        train_targets,
        batch_size=batch_size,
        epochs=epochs,
        verbose=1,
        shuffle=True,
        validation_data=(X_val, val_targets),
        callbacks=[best_ckpt, tb_logger],
    )

    model.save(os.path.join(model_dir, f'final_{run_tag}.h5'))

    return model, history


def calculate_metrics(model, X_test, y_test_binary):
    """Evaluate a single-output classifier against one-hot test labels.

    Args:
        model: Object with a ``predict`` method returning per-class scores
            whose argmax along axis 1 is the predicted class.
        X_test: Test inputs passed to ``model.predict``.
        y_test_binary: Test labels; the true class is taken as the argmax
            along axis 1, so presumably one-hot encoded.

    Returns:
        A tuple ``(cf_matrix, accuracy, macro_f1, mismatch, y_pred)`` where
        ``mismatch`` is the ``np.where`` tuple of indices of misclassified
        samples and ``y_pred`` is the array of predicted class indices.
    """
    predictions = np.argmax(model.predict(X_test), axis=1)
    labels = np.argmax(y_test_binary, axis=1)
    wrong = np.where(labels != predictions)

    return (
        confusion_matrix(labels, predictions),
        accuracy_score(labels, predictions),
        f1_score(labels, predictions, average='macro'),
        wrong,
        predictions,
    )

def calculate_concept_metrics(model, X_test, y_test_binary, c_test, threshold=0.5):
    """Evaluate both heads of a concept model on a test set.

    Class metrics are computed from the argmax of the second prediction
    output. Concept metrics are computed element-wise over all flattened
    (sample, concept) pairs after binarizing the first prediction output.

    Args:
        model: Object whose ``predict`` returns a sequence where ``[0]`` holds
            concept probabilities and ``[1]`` holds class scores.
        X_test: Test inputs passed to ``model.predict``.
        y_test_binary: Class labels; the true class is the argmax along
            axis 1, so presumably one-hot encoded.
        c_test: Binary concept labels, flattened before comparison; must
            contain as many elements as the concept predictions.
        threshold: Concept probabilities strictly greater than this are
            counted as present (1), all others as absent (0). Defaults to 0.5.

    Returns:
        ``(cf_matrix, accuracy, macro_f1, mismatch, y_pred, cf_concepts,
        accuracy_concepts)``, where ``mismatch`` is the ``np.where`` tuple of
        misclassified sample indices.
    """
    pred = model.predict(X_test)
    # NOTE(review): relies on the model's output order being
    # [concept head, class head]; confirm against ``model.outputs``.
    concept_probs, class_scores = pred[0], pred[1]

    # --- class head ---
    y_pred = np.argmax(class_scores, axis=1)
    y_true = np.argmax(y_test_binary, axis=1)
    mismatch = np.where(y_true != y_pred)
    cf_matrix = confusion_matrix(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average='macro')

    # --- concept head: score every (sample, concept) pair independently ---
    c_true = np.asarray(c_test).ravel()
    c_pred = (np.asarray(concept_probs).ravel() > threshold).astype(int)
    cf_concepts = confusion_matrix(c_true, c_pred)
    accuracy_concepts = accuracy_score(c_true, c_pred)

    return cf_matrix, accuracy, macro_f1, mismatch, y_pred, cf_concepts, accuracy_concepts


def get_attention(model, layer_name, input_data):
    """Return the output of the layer ``layer_name`` for ``input_data``.

    A temporary Keras ``Model`` is built that shares ``model``'s input(s) and
    ends at the requested layer, then it is run with ``predict``.

    Args:
        model: A built Keras model.
        layer_name: Name of the layer whose activations are wanted.
        input_data: Inputs accepted by ``model``.

    Returns:
        Whatever ``predict`` returns for that layer's output.
    """
    inputs = model.input
    outputs = model.get_layer(layer_name).output
    probe = Model(inputs=inputs, outputs=outputs)
    return probe.predict(input_data)

def visualize_concepts(test_input, model, concepts_text, class_names=None, threshold=0.5):
    """Print and plot the concepts predicted for a single sample.

    The sample is batched, run through ``model``, and the predicted class and
    the indices of concepts with probability ``>= threshold`` are printed.
    The values of the ``attn_score`` layer at those concept indices are then
    drawn as a horizontal bar chart labelled with the concepts' text.

    Args:
        test_input: One unbatched sample; a leading batch axis is added.
        model: Concept model whose ``predict`` returns ``[concept_probs,
            class_scores]`` and which has a layer named ``attn_score``.
        concepts_text: Table with a ``text`` column indexed positionally by
            concept index (e.g. a pandas DataFrame).
        class_names: Mapping from class index to display label. When None,
            a module-level ``inv_class_dict`` is used if one exists.
        threshold: Concepts with probability ``>=`` this are shown.
            Defaults to 0.5.

    Raises:
        NameError: If ``class_names`` is None and no module-level
            ``inv_class_dict`` is defined.
    """
    if class_names is None:
        # Backward compatibility: earlier callers relied on a module-level
        # ``inv_class_dict`` rather than passing the mapping in.
        class_names = globals().get('inv_class_dict')
        if class_names is None:
            raise NameError(
                "visualize_concepts needs a class-index -> label mapping: pass "
                "class_names=... or define a module-level 'inv_class_dict'.")

    test_input = np.expand_dims(test_input, axis=0)
    pred = model.predict(test_input)
    pred_class = class_names[np.argmax(pred[1], axis=1)[0]]
    # pred[0] has a leading batch axis of size 1, so element [1] of the
    # np.where tuple holds the concept indices.
    pred_concepts = np.where(pred[0] >= threshold)
    print(f'Predicted Concepts: {pred_concepts[1]}')
    print(f'Predicted Class: {pred_class}')
    # NOTE(review): assumes the squeezed attention is indexed by concept
    # along its first axis; with a single concept np.squeeze yields a 0-d
    # array and the indexing below would fail — confirm attention shape.
    attention = np.squeeze(get_attention(model, 'attn_score', test_input))
    pred_attn = attention[pred_concepts[1]]
    pred_text = concepts_text['text'].iloc[pred_concepts[1]]

    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'
    # and later releases reject the old names, so prefer the new one.
    style = ('seaborn-v0_8-whitegrid'
             if 'seaborn-v0_8-whitegrid' in plt.style.available
             else 'seaborn-whitegrid')
    plt.style.use(style)
    plt.rcParams.update({'font.size': 18})
    fig, ax = plt.subplots(figsize=(6, 5))

    y_pos = np.arange(len(pred_text))
    ax.barh(y_pos, pred_attn, align='center')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(pred_text, fontsize=16)
    ax.invert_yaxis()  # labels read top-to-bottom
    ax.set_xlabel('Concept Score', fontsize=20)
    ax.set_title(f'Predicted Activity: {pred_class}', fontsize=20)
    plt.show()
