import os,sys
import numpy as np
import pickle
import tensorflow as tf
import time
import copy as cp
from collections import Counter
# from tensorflow import keras
import tensorflow.keras as K
import matplotlib.pyplot as plt
import logging
if __name__ == "__main__":
    sys.path.insert(0, os.path.join(os.getcwd(), os.pardir))
    # sys.path.insert(0, os.path.join(os.getcwd()))
    from params import model_params
    import utils
    from utils import DotDict
    from model.EEG_Pretrain_Sleep_transformer.naive_transformer import naive_transformer


def _parse_function(example_proto, data_type):
    """Decode one serialized ``tf.train.Example`` into six float32 tensors.

    Every feature is stored as a serialized float32 tensor (a scalar string
    decoded with ``tf.io.parse_tensor``) and is reshaped to a fixed shape.

    Args:
        example_proto: one serialized example (scalar string tensor).
        data_type: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(N2N3_audio, id, audio, image_last_hidden_state,
        image_pooler_output, category_name)`` with shapes ``[80, 55]``,
        ``[68]``, ``[25, 1024]``, ``[768]``, ``[1, 768]`` and ``[15]``.
    """
    feature_keys = (
        'category_name',
        'feature_N2N3_audio',
        'feature_id',
        'feature_audio',
        'feature_image_last_hidden_state',
        'feature_image_pooler_output',
    )
    schema = {key: tf.io.FixedLenFeature((), tf.string) for key in feature_keys}
    serialized = tf.io.parse_single_example(example_proto, schema)
    decoded = {key: tf.io.parse_tensor(serialized[key], tf.float32) for key in feature_keys}

    # Stored as 55 x 80 and returned transposed to 80 x 55.
    n2n3_audio = tf.transpose(tf.reshape(decoded['feature_N2N3_audio'], shape=[55, 80]), perm=[1, 0])
    subject_id = tf.reshape(decoded['feature_id'], shape=[68])
    audio = tf.reshape(decoded['feature_audio'], shape=[25, 1024])
    category = tf.reshape(decoded['category_name'], shape=[15])

    # Average rows 1..196 of the 197 x 768 matrix, skipping row 0
    # (presumably a CLS token - confirm against the feature extractor).
    hidden_states = tf.reshape(decoded['feature_image_last_hidden_state'], [197, 768])
    image_embedding = tf.squeeze(tf.reduce_mean(hidden_states[1:, :], axis=0))
    # Static-shape sanity check, evaluated when the function is traced.
    assert image_embedding.shape == [768]
    pooler = tf.reshape(decoded['feature_image_pooler_output'], [1, 768])

    return n2n3_audio, subject_id, audio, image_embedding, pooler, category

def get_dataset_from_tfrecords(record_path, batch_size, shuffle_buffer_size, cycle_length, data_type):
    """Build a shuffled, batched ``tf.data`` pipeline over TFRecord shards.

    Args:
        record_path: file path, glob pattern, or list of them, passed to
            ``tf.data.Dataset.list_files`` (file order is shuffled).
        batch_size: number of examples per batch.
        shuffle_buffer_size: buffer size for example-level shuffling.
        cycle_length: number of files ``interleave`` cycles through at a time.
        data_type: forwarded unchanged to ``_parse_function``.

    Returns:
        A ``tf.data.Dataset`` whose elements are batches of the 6-tuple
        produced by ``_parse_function``.
    """
    autotune = tf.data.experimental.AUTOTUNE
    files = tf.data.Dataset.list_files(record_path, shuffle=True)
    dataset = files.interleave(
        map_func=tf.data.TFRecordDataset, cycle_length=cycle_length)
    # Parse in parallel; map() keeps element order (deterministic by default).
    dataset = dataset.map(lambda x: _parse_function(x, data_type),
                          num_parallel_calls=autotune)
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.batch(batch_size=batch_size)
    # prefetch() counts elements of *this* dataset, i.e. whole batches. The
    # former `batch_size * 4` asked for 1024 batches of 256 examples each;
    # let tf.data size the buffer instead.
    dataset = dataset.prefetch(autotune)
    return dataset



# @tf.function
def train_step(model_decode_meg, train_data , category_id, train_wav, train_subjectid, optimizer):
    """Run one optimisation step on a single model.

    The model is called with the 4-tuple
    ``(train_data, category_id, train_wav, train_subjectid)`` and
    ``training=True``. It must return ``(loss, result_matrix)``. The gradient
    of the loss with respect to the model's trainable variables is applied
    with ``optimizer``.

    Returns:
        ``(loss_value, Result_Matrix)`` exactly as produced by the forward pass.
    """
    model_inputs = (train_data, category_id, train_wav, train_subjectid)
    with tf.GradientTape() as tape:
        loss_value, result_matrix = model_decode_meg(model_inputs, training=True)

    # Read the variables only after the forward pass: a Keras model may
    # create its weights lazily on its first call.
    variables = model_decode_meg.trainable_variables
    gradients = tape.gradient(loss_value, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss_value, result_matrix

# @tf.function
def validation_step(model_decode_meg, validation_data , category_id, validation_wav,validation_subjectid):
    """Evaluate a model on one batch without updating it.

    Calls the model with the 4-tuple of inputs and ``training=False``, and
    returns its ``(loss_value, Result_Matrix)`` pair.
    """
    batch = (validation_data, category_id, validation_wav, validation_subjectid)
    outputs = model_decode_meg(batch, training=False)
    loss_value, result_matrix = outputs
    return loss_value, result_matrix
# @tf.function
def test_step(model_test, test_data , category_id, test_wav,  test_subjectid):

    loss_value, Result_Matrix = \
          model_test((test_data , category_id, test_wav, test_subjectid), training = False)

    return loss_value, Result_Matrix




if __name__ == "__main__":
    # Pre-train an ensemble of `naive_transformer` models on every TFRecord
    # shard (train, val and test) of one sleep state. The weights of every
    # member are checkpointed every second epoch. The project imports used
    # here, and the sys.path entry they need, are set up in the guarded import
    # block at the top of the file.
    batch_size = 256
    shuffle_buffer_size = batch_size
    cycle_length = 1
    ensemble_models = 10
    loss_mode = 'CLIP'  # NOTE(review): not read anywhere in the visible code
    type_list = ['origin', 'simple']
    data_type = type_list[0]
    state_list = ['REM', 'N2N3']
    state = state_list[1]

    individual_models = [naive_transformer() for _ in range(ensemble_models)]
    # One Adam instance per ensemble member. A single shared optimizer
    # advances its `iterations` counter once per member per batch, which skews
    # Adam's bias correction for every member. Newer Keras optimizers
    # (TF >= 2.11) also refuse to update variables other than the ones they
    # were built with.
    optimizers = [K.optimizers.Adam(learning_rate=5e-5) for _ in individual_models]

    basepath = os.path.join(os.getcwd(), os.pardir)
    recordpath = os.path.join(basepath, 'TFRecord')
    checkpath = os.path.join(basepath, 'checkpoint')
    log_format = "%(asctime)s - %(message)s"
    log_file_path = os.path.join(basepath, 'logger/EEG_Pretrain_Sleep_transformer.log')
    # logging's FileHandler does not create missing parent directories.
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    logging.basicConfig(filename=log_file_path, format=log_format, level=logging.INFO)
    path_record = os.path.join(recordpath, "TFRecord_EEG_BJ_40/sleep/"+str(state)+'_'+str(data_type))
    path_result = os.path.join(basepath, 'EEG_Result/')  # NOTE(review): unused in the visible code

    record_names = os.listdir(path_record)
    path_train0 = sorted(os.path.join(path_record, name) for name in record_names if name.startswith("train"))
    path_val = sorted(os.path.join(path_record, name) for name in record_names if name.startswith("val"))
    path_test = sorted(os.path.join(path_record, name) for name in record_names if name.startswith("test"))
    # All three splits are used for training in this script.
    path_train = path_train0 + path_val + path_test

    def _report(msg):
        """Print `msg` and also write it to the INFO log."""
        print(msg)
        logging.info(msg)

    epochs = 500
    # NOTE(review): the lists and counter below are not used in the visible code.
    train_loss = []
    val_loss = []
    test_loss = []
    train_accuracy = []
    val_accuracy = []
    test_accuracy = []
    no_improvement_epoches = 0
    # NOTE(review): the @tf.function decorators on the step functions are
    # commented out, so this setting likely has no effect on them.
    tf.config.run_functions_eagerly(False)
    for epoch in range(epochs):
        _report("\nStart of epoch %d" % (epoch,))
        start_time = time.time()

        # Train every ensemble member on each batch, then report the
        # ensemble-averaged loss and accuracy.
        train_losses = 0.0
        Accuracy_train = []
        train_dataset = get_dataset_from_tfrecords(
            record_path=path_train, batch_size=batch_size,
            shuffle_buffer_size=shuffle_buffer_size, cycle_length=cycle_length,
            data_type=data_type)

        for input_data in train_dataset:
            # Batched 6-tuple from _parse_function; the two image features are
            # not used for training here.
            n2n3_audio, subject_id, audio, _, _, category = input_data
            Loss_value = []
            Result_Matrix = []
            for model, model_optimizer in zip(individual_models, optimizers):
                member_loss, member_result = train_step(
                    model, n2n3_audio, category, audio, subject_id, model_optimizer)
                Loss_value.append(member_loss)
                Result_Matrix.append(member_result)
            loss_value = tf.reduce_mean(Loss_value, axis=0).numpy()
            Result_Matrix1 = tf.reduce_mean(Result_Matrix, axis=0).numpy()
            # Fraction of rows where the argmax of the averaged result matrix
            # equals the argmax of `category`.
            hits = np.argmax(Result_Matrix1, axis=-1) == np.argmax(category, axis=-1)
            Accuracy_train.append(np.sum(hits) / hits.size)
            train_losses += loss_value

        if not Accuracy_train:
            raise RuntimeError("No training batches were read from " + path_record)
        n_batches = len(Accuracy_train)
        train_losses = train_losses / n_batches
        Accuracy_train = sum(Accuracy_train) / n_batches

        _report('Training Step Finished.')

        if np.isnan(float(train_losses)):
            _report("Training loss is NaN; stopping.")
            break

        _report("Training Loss over epoch: %.4f" % (float(train_losses),))
        _report("Training Accuracy over epoch: %.2f" % (float(Accuracy_train*100),))
        _report("Time taken: %.2fs" % (time.time() - start_time))

        if (epoch+1) % 2 == 0:
            checkpoint_dir = os.path.join(checkpath, "checkpoint_EEG_pretrain_sleep_bj")
            os.makedirs(checkpoint_dir, exist_ok=True)
            for i, model in enumerate(individual_models):
                # File names use the 0-based epoch index (first save: epoch01).
                # Only the file name is formatted, so braces elsewhere in the
                # path are never treated as format fields.
                checkpoint_path = os.path.join(
                    checkpoint_dir,
                    "transformer_new_epoch{:02d}_model{:02d}.ckpt".format(epoch, i))
                model.save_weights(checkpoint_path)