import os,sys
import numpy as np
import pickle
import tensorflow as tf
import time
import copy as cp
from collections import Counter
import tensorflow.keras as K
import matplotlib.pyplot as plt
import logging
# Project-local modules are imported only when this file is executed as a
# script; when the file is imported as a module these names are not defined.
if __name__ == "__main__":
    # Put the parent of the current working directory at the front of sys.path
    # before the imports below (presumably that is where `params`, `utils` and
    # `model` live - confirm against the repository layout).
    sys.path.insert(0, os.path.join(os.getcwd(), os.pardir))
    from params import model_params
    import utils
    from utils import DotDict
    from model.EEG_Pretrain_Awake_Sleep_transformer.naive_transformer import naive_transformer


def _parse_function(example_proto, data_size):
    """Decode one serialized example into five float32 tensors.

    Every field of the record is stored as a scalar string holding a
    serialized tensor, which is parsed back as ``tf.float32``.

    Args:
        example_proto: A single serialized example.
        data_size: Accepted for interface compatibility; not used here.

    Returns:
        ``(awake_image, awake_audio, N2N3_audio, category_name, id)`` where
        the first three are reshaped to ``[55, 80]`` and transposed to
        ``[80, 55]``, ``category_name`` is reshaped to ``[15]`` and the id
        tensor to ``[70]``.
    """
    string_scalar = tf.io.FixedLenFeature((), tf.string)
    schema = {
        'category_name': string_scalar,
        'feature_awake_image': string_scalar,
        'feature_awake_audio': string_scalar,
        'feature_N2N3_audio': string_scalar,
        'feature_id': string_scalar,
    }
    record = tf.io.parse_single_example(example_proto, schema)

    def decode(key):
        # Each field is a serialized float32 tensor.
        return tf.io.parse_tensor(record[key], tf.float32)

    def decode_signal(key):
        # Stored as [55, 80]; returned with the two axes swapped.
        return tf.transpose(tf.reshape(decode(key), shape=[55, 80]), perm=[1, 0])

    category_name = tf.reshape(decode('category_name'), shape=[15])
    awake_image = decode_signal('feature_awake_image')
    awake_audio = decode_signal('feature_awake_audio')
    N2N3_audio = decode_signal('feature_N2N3_audio')
    subject_id = tf.reshape(decode('feature_id'), shape=[70])

    return awake_image, awake_audio, N2N3_audio, category_name, subject_id

def get_dataset_from_tfrecords(record_path, batch_size, shuffle_buffer_size, cycle_length, data_size):
    """Build a shuffled, batched ``tf.data`` pipeline over TFRecord files.

    Args:
        record_path: File pattern(s) or path(s) passed to
            ``tf.data.Dataset.list_files``.
        batch_size: Number of examples per batch; also used as the prefetch
            buffer size.
        shuffle_buffer_size: Buffer size for the example-level shuffle.
        cycle_length: ``cycle_length`` passed to ``interleave``.
        data_size: Forwarded unchanged to ``_parse_function``.

    Returns:
        A dataset whose elements are batches of the tuples produced by
        ``_parse_function``.
    """
    def parse_record(serialized):
        return _parse_function(serialized, data_size)

    shard_names = tf.data.Dataset.list_files(record_path, shuffle=True)
    serialized_examples = shard_names.interleave(
        map_func=tf.data.TFRecordDataset, cycle_length=cycle_length)
    return (
        serialized_examples
        .map(parse_record)
        .shuffle(buffer_size=shuffle_buffer_size)
        .batch(batch_size=batch_size)
        .prefetch(batch_size)
    )



# @tf.function
def train_step(model_decode_meg, image_data ,audio_data, sleep_data, true_label, subjectid, optimizer):
    """Run one optimisation step of ``model_decode_meg`` on a single batch.

    The model is called with the five inputs packed into one tuple and
    ``training=True``; it must return ``(loss, result_1, result_2)``. The
    loss is differentiated with respect to the model's trainable variables
    and the gradients are applied through ``optimizer``.

    Returns:
        The ``(loss, result_1, result_2)`` triple produced by the model.
    """
    batch = (image_data, audio_data, sleep_data, true_label, subjectid)

    # Only the forward pass needs to be recorded on the tape.
    with tf.GradientTape() as tape:
        loss_value, first_result, second_result = model_decode_meg(batch, training=True)

    weights = model_decode_meg.trainable_variables
    gradients = tape.gradient(loss_value, weights)
    optimizer.apply_gradients(zip(gradients, weights))

    return loss_value, first_result, second_result

# @tf.function
def validation_step(model_decode_meg, image_data ,audio_data, sleep_data, true_label, subjectid):
    """Evaluate ``model_decode_meg`` on one batch without updating it.

    The five inputs are packed into a single tuple and the model is called
    with ``training=False``. The model must return exactly three values.

    Returns:
        A ``(loss, result_1, result_2)`` tuple with the model's outputs.
    """
    batch = (image_data, audio_data, sleep_data, true_label, subjectid)
    outputs = model_decode_meg(batch, training=False)
    # Unpacking enforces the three-output contract, as before.
    loss_value, first_result, second_result = outputs
    return loss_value, first_result, second_result
# @tf.function
def test_step(model, image_data ,audio_data, sleep_data, true_label, subjectid):

    loss_value, Result_Matrix, Result_Matrix2 = \
          model((image_data ,audio_data, sleep_data, true_label, subjectid), training = False)

    return loss_value, Result_Matrix, Result_Matrix2


def load_whole_data(path_train, path_val, path_test, batch_size):
    """Load three pickled splits and merge them into one shuffled dataset.

    Each pickle must hold an integer-indexable collection of records, every
    record providing the keys ``'category'``, ``'awake_image'``,
    ``'N2N3_audio'`` and ``'subjectid'``. The records of the train,
    validation and test files are concatenated in that order.

    Args:
        path_train: Path of the pickled training split.
        path_val: Path of the pickled validation split.
        path_test: Path of the pickled test split.
        batch_size: Number of examples per batch (the last batch may be
            smaller).

    Returns:
        A ``tf.data.Dataset`` yielding batches of
        ``(awake_image, N2N3_audio, category, subject_id)``, shuffled over
        the whole merged set, with the last two axes of ``awake_image`` and
        ``N2N3_audio`` swapped relative to the stored arrays.

    Raises:
        ValueError: If the three files contain no records at all.
    """
    splits = []
    for split_path in (path_train, path_val, path_test):
        # NOTE(security): pickle.load can execute arbitrary code; only use
        # this with files from a trusted source.
        # The context manager closes the handle even if unpickling fails.
        with open(split_path, 'rb') as split_file:
            splits.append(pickle.load(split_file))

    category, awake_image, N2N3_audio, subject_id = [], [], [], []
    for split in splits:
        # Access by position (not by iteration) so that containers keyed by
        # 0..n-1 keep working exactly as before.
        for index in range(len(split)):
            record = split[index]
            category.append(record['category'])
            awake_image.append(record['awake_image'])
            N2N3_audio.append(record['N2N3_audio'])
            subject_id.append(record['subjectid'])

    if not category:
        raise ValueError(
            "No records found in %r, %r or %r" % (path_train, path_val, path_test))

    category = np.asarray(category)
    subject_id = np.asarray(subject_id)
    # Swap the last two axes of every stacked sample.
    awake_image = np.transpose(np.asarray(awake_image), (0, 2, 1))
    N2N3_audio = np.transpose(np.asarray(N2N3_audio), (0, 2, 1))

    dataset_length = len(category)
    dataset_ = tf.data.Dataset.from_tensor_slices(
        (awake_image, N2N3_audio, category, subject_id))
    # A shuffle buffer as large as the dataset shuffles it completely.
    Dataset = dataset_.shuffle(dataset_length).batch(
        batch_size, drop_remainder=False).prefetch(2)

    return Dataset



def batch_accuracy(scores, one_hot_labels):
    """Return the fraction of rows whose arg-max matches the label's arg-max.

    Both arguments are reduced with ``np.argmax`` over their last axis, so
    anything NumPy can convert to an array is accepted.
    """
    hits = np.argmax(scores, axis=-1) == np.argmax(one_hot_labels, axis=-1)
    return np.sum(hits) / hits.size


def _log_and_print(message):
    """Write ``message`` to stdout and to the root logger at INFO level."""
    print(message)
    logging.info(message)


# Script entry point: trains an ensemble of `naive_transformer` models on all
# TFRecord shards (train + val + test) and saves their weights every 2 epochs.
if __name__ == "__main__":
    sys.path.insert(0, os.path.join(os.getcwd(), os.pardir))
    batch_size = 256
    shuffle_buffer_size = batch_size
    cycle_length = 1
    ensemble_models = 15

    utils.model.set_seeds(1642)
    basepath = os.path.join(os.getcwd(), os.pardir)
    recordpath = os.path.join(basepath, 'TFRecord')
    checkpath = os.path.join(basepath, 'checkpoint')

    # Configure the logger
    log_format = "%(asctime)s - %(message)s"
    log_file_path = os.path.join(basepath, 'logger/EEG_Pretrain_Awake_Sleep_transformer.log')
    # basicConfig fails if the log directory is missing, so create it first.
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    logging.basicConfig(filename=log_file_path, format=log_format, level=logging.INFO)

    individual_models = [naive_transformer() for _ in range(ensemble_models)]

    path_record = os.path.join(recordpath, "TFRecord_EEG_BJ_40")
    path_result = os.path.join(basepath, 'EEG_Result/')

    record_names = os.listdir(path_record)
    path_train = sorted(
        os.path.join(path_record, name) for name in record_names if name.startswith("train"))
    path_val = sorted(
        os.path.join(path_record, name) for name in record_names if name.startswith("val"))
    path_test = sorted(
        os.path.join(path_record, name) for name in record_names if name.startswith("test"))
    # All three splits are used for training in this script.
    path_train = path_train + path_val + path_test

    # One optimizer per model: Adam keeps per-variable state and a step
    # counter, so a single instance must not be shared between models with
    # different variable sets.
    optimizers = [K.optimizers.Adam(learning_rate=5e-5) for _ in individual_models]

    # training config
    epochs = 500
    # The history lists and counter below are not populated by this loop.
    train_loss = []
    val_loss = []
    test_loss = []
    train_accuracy = []
    val_accuracy = []
    test_accuracy = []
    no_improvement_epoches = 0
    tf.config.run_functions_eagerly(False)

    for epoch in range(epochs):
        _log_and_print("\nStart of epoch %d" % (epoch,))
        start_time = time.time()

        # Train
        train_losses = 0.0
        Accuracy_train = []
        Accuracy_train2 = []
        Accuracy_train3 = []
        train_dataset = get_dataset_from_tfrecords(
            record_path=path_train, batch_size=batch_size,
            shuffle_buffer_size=shuffle_buffer_size,
            cycle_length=cycle_length, data_size=ensemble_models)

        for input_data in train_dataset:
            image_data, audio_data, sleep_data, true_label, subjectid = input_data

            Loss_value = []
            Result_Matrix1 = []
            Result_Matrix_2 = []
            # Every ensemble member is trained on the same batch.
            for model, model_optimizer in zip(individual_models, optimizers):
                loss_value, Result_Matrix, Result_Matrix2 = train_step(
                    model, image_data, audio_data, sleep_data, true_label,
                    subjectid, model_optimizer)
                Loss_value.append(loss_value)
                Result_Matrix1.append(Result_Matrix)
                Result_Matrix_2.append(Result_Matrix2)

            # Ensemble outputs are the mean over the individual models.
            loss_value = tf.reduce_mean(Loss_value, axis=0)
            Result_Matrix = tf.reduce_mean(Result_Matrix1, axis=0)
            Result_Matrix2 = tf.reduce_mean(Result_Matrix_2, axis=0)
            # Combine the two *averaged* outputs. Adding the per-model list
            # `Result_Matrix1` here would broadcast to an extra model axis and
            # turn "Accuracy3" into a mean of per-model accuracies.
            Result_Matrix3 = tf.nn.softmax(Result_Matrix2 + Result_Matrix, axis=-1)

            Accuracy_train.append(batch_accuracy(Result_Matrix.numpy(), true_label))
            Accuracy_train2.append(batch_accuracy(Result_Matrix2.numpy(), true_label))
            Accuracy_train3.append(batch_accuracy(Result_Matrix3, true_label))
            train_losses += loss_value

        if not Accuracy_train:
            # Without this the averages below fail with ZeroDivisionError.
            raise RuntimeError("No training batches were read from %s" % path_record)

        train_losses = train_losses / len(Accuracy_train)
        Accuracy_train = sum(Accuracy_train) / len(Accuracy_train)
        Accuracy_train2 = sum(Accuracy_train2) / len(Accuracy_train2)
        Accuracy_train3 = sum(Accuracy_train3) / len(Accuracy_train3)

        _log_and_print('Training Step Finished.')
        _log_and_print("Training Loss over epoch: %.4f" % (float(train_losses),))
        _log_and_print("Training Accuracy over epoch: %.2f" % (float(Accuracy_train*100),))
        _log_and_print("Training Accuracy2 over epoch: %.2f" % (float(Accuracy_train2*100),))
        _log_and_print("Training Accuracy3 over epoch: %.2f" % (float(Accuracy_train3*100),))
        _log_and_print("Time taken: %.2fs" % (time.time() - start_time))

        # Save every model's weights after every second epoch.
        if (epoch + 1) % 2 == 0:
            for i, model in enumerate(individual_models):
                # Format only the file-name template so that braces in
                # `checkpath` cannot break it. The file name keeps the
                # 0-based epoch index, as before.
                checkpoint_name = "checkpoint_EEG_pretrain_awake_sleep_bj_single/distribution_origin_epoch{:02d}_model{:02d}.ckpt".format(epoch, i)
                checkpoint_path = os.path.join(checkpath, checkpoint_name)
                os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
                model.save_weights(checkpoint_path)
