import os,sys
import numpy as np
import pickle
import tensorflow as tf
import time
import copy as cp
from collections import Counter
import tensorflow.keras as K
import matplotlib.pyplot as plt
import logging
if __name__ == "__main__":
    sys.path.insert(0, os.path.join(os.getcwd(), os.pardir))
    from params import model_params
    import utils
    from utils import DotDict
    from model.EEG_Finetune_Awake_Sleep_transformer.naive_transformer import naive_transformer


def _parse_function(example_proto, data_type):
    """Decode one serialized ``tf.train.Example`` into fixed-shape float32 tensors.

    Every feature is stored as a serialized float32 tensor inside a bytes
    field; each is parsed with ``tf.io.parse_tensor`` and then reshaped.

    Args:
        example_proto: a scalar string tensor holding one serialized Example.
        data_type: unused.

    Returns:
        A 7-tuple ``(N2N3_audio, feature_id, audio, image_embedding,
        image_pooler_output, category_name, N2N3_audio_t)`` with shapes
        ``(80, 55)``, ``(33,)``, ``(25, 1024)``, ``(768,)``, ``(1, 768)``,
        ``(15,)`` and ``(80, 55)``.
    """
    keys = (
        'category_name',
        'feature_N2N3_audio',
        'feature_N2N3_audio_t',
        'feature_id',
        'feature_audio',
        'feature_image_last_hidden_state',
        'feature_image_pooler_output',
    )
    spec = {key: tf.io.FixedLenFeature((), tf.string) for key in keys}
    raw = tf.io.parse_single_example(example_proto, spec)
    decoded = {key: tf.io.parse_tensor(raw[key], tf.float32) for key in keys}

    def _reshape_transpose(flat):
        # Stored as (55, 80); returned with the two axes swapped -> (80, 55).
        return tf.transpose(tf.reshape(flat, shape=[55, 80]), perm=[1, 0])

    sleep_audio = _reshape_transpose(decoded['feature_N2N3_audio'])
    sleep_audio_t = _reshape_transpose(decoded['feature_N2N3_audio_t'])
    feature_id = tf.reshape(decoded['feature_id'], shape=[33])
    audio = tf.reshape(decoded['feature_audio'], shape=[25, 1024])
    category = tf.reshape(decoded['category_name'], shape=[15])

    hidden = tf.reshape(decoded['feature_image_last_hidden_state'], [197, 768])
    # Mean over rows 1..196, skipping row 0 (presumably a CLS token -- confirm).
    image_embedding = tf.squeeze(tf.reduce_mean(hidden[1:, :], axis=0))
    assert image_embedding.shape == [768]

    pooled = tf.reshape(decoded['feature_image_pooler_output'], [1, 768])
    return sleep_audio, feature_id, audio, image_embedding, pooled, category, sleep_audio_t




def get_dataset_from_tfrecords(record_path, batch_size, shuffle_buffer_size, cycle_length, data_type):
    """Build a shuffled, batched dataset from TFRecord files matching ``record_path``.

    Args:
        record_path: glob pattern (or list of patterns) for the TFRecord files.
        batch_size: number of examples per batch; also used as the prefetch depth.
        shuffle_buffer_size: size of the example-level shuffle buffer.
        cycle_length: how many files ``interleave`` reads concurrently.
        data_type: forwarded to ``_parse_function``.

    Returns:
        A ``tf.data.Dataset`` yielding batches of the 7-tuple produced by
        ``_parse_function``.
    """
    def parse(serialized):
        return _parse_function(serialized, data_type)

    return (
        tf.data.Dataset.list_files(record_path, shuffle=True)
        .interleave(tf.data.TFRecordDataset, cycle_length=cycle_length)
        .map(parse)
        .shuffle(buffer_size=shuffle_buffer_size)
        .batch(batch_size=batch_size)
        .prefetch(batch_size)
    )



# @tf.function
def train_step(model_decode_meg, image_data, audio_data, sleep_data, true_label, subjectid, optimizer):
    """Apply one gradient update to ``model_decode_meg`` for a single batch.

    The model is called in training mode on the 5-tuple
    ``(image_data, audio_data, sleep_data, true_label, subjectid)`` and is
    expected to return ``(loss, result_matrix, result_matrix2)``. The loss is
    differentiated with respect to the model's trainable variables and the
    gradients are handed to ``optimizer``.

    Returns:
        The model's ``(loss, result_matrix, result_matrix2)`` for this batch.
    """
    batch = (image_data, audio_data, sleep_data, true_label, subjectid)
    with tf.GradientTape() as tape:
        # Unpack inside the tape so any slicing of the output is recorded too.
        loss_value, result_matrix, result_matrix2 = model_decode_meg(batch, training=True)

    # Fetched after the forward pass: a Keras model may only create its
    # variables on its first call.
    weights = model_decode_meg.trainable_variables
    gradients = tape.gradient(loss_value, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return loss_value, result_matrix, result_matrix2

# @tf.function
def validation_step(model_decode_meg, image_data, audio_data, sleep_data, true_label, subjectid):
    """Evaluate one batch in inference mode (``training=False``), without updating weights.

    Returns:
        The 3-tuple ``(loss, result_matrix, result_matrix2)`` returned by the model.
    """
    batch = (image_data, audio_data, sleep_data, true_label, subjectid)
    loss_value, scores, scores2 = model_decode_meg(batch, training=False)
    return loss_value, scores, scores2
# @tf.function
def test_step(model, image_data, audio_data, sleep_data, true_label, subjectid):
    """Score one held-out batch with ``model`` in inference mode (``training=False``).

    Returns:
        The 3-tuple ``(loss, result_matrix, result_matrix2)`` returned by the model.
    """
    loss_value, scores, scores2 = model(
        (image_data, audio_data, sleep_data, true_label, subjectid),
        training=False,
    )
    return loss_value, scores, scores2


def load_data(path, batch_size, data_type=None):
    """Load a pickled dataset into one shuffled, batched ``tf.data`` pipeline.

    Args:
        path: pickle file holding an indexable sequence of dicts with keys
            'category', 'awake_image', 'awake_audio', 'N2N3_audio', 'subjectid'.
        batch_size: samples per batch; the final batch may be smaller.
        data_type: unused; optional for call-site compatibility.

    Returns:
        A dataset yielding ``(awake_image, awake_audio, N2N3_audio, category,
        subject_id)`` batches. The last two axes of each ``awake_image`` /
        ``awake_audio`` / ``N2N3_audio`` sample are swapped relative to the
        stored arrays.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as file:  # closes the handle, which a bare open() leaked
        dataset = pickle.load(file)
    records = [dataset[i] for i in range(len(dataset))]

    category = np.asarray([r['category'] for r in records])
    # Swap the last two axes of every sample.
    awake_image = np.transpose(np.asarray([r['awake_image'] for r in records]), (0, 1, 3, 2))
    awake_audio = np.transpose(np.asarray([r['awake_audio'] for r in records]), (0, 1, 3, 2))
    N2N3_audio = np.transpose(np.asarray([r['N2N3_audio'] for r in records]), (0, 2, 1))
    subject_id = np.asarray([r['subjectid'] for r in records])

    dataset_ = tf.data.Dataset.from_tensor_slices(
        (awake_image, awake_audio, N2N3_audio, category, subject_id))
    # A buffer as large as the dataset gives a full shuffle, then batch.
    return dataset_.shuffle(len(category)).batch(
        batch_size, drop_remainder=False).prefetch(2)

def load_whole_data(path, batch_size, ratio, data_type=None):
    """Split one pickled dataset into train / validation / test ``tf.data`` pipelines.

    Records are split in file order (no shuffling before the split):
    - the first ``int(n * ratio)`` records go to training;
    - the next ``int(n * (1 - ratio) / 2)`` records go to validation;
    - the remainder go to test.

    Each split is then shuffled and batched independently.

    Args:
        path: pickle file holding an indexable sequence of dicts with keys
            'category', 'awake_image', 'awake_audio', 'N2N3_audio', 'subjectid'.
        batch_size: samples per batch; the last batch of each split may be smaller.
        ratio: fraction of records used for training.
        data_type: unused. It defaults to ``None`` so that three-argument calls
            (as in this file's ``__main__``) no longer raise ``TypeError``.

    Returns:
        ``(train_Dataset, val_Dataset, test_Dataset)``, each yielding
        ``(awake_image, awake_audio, N2N3_audio, category, subject_id)`` batches.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as file:
        dataset = pickle.load(file)
    records = [dataset[i] for i in range(len(dataset))]

    category = np.asarray([r['category'] for r in records])
    # Swap the last two axes of every sample.
    awake_image = np.transpose(np.asarray([r['awake_image'] for r in records]), (0, 1, 3, 2))
    awake_audio = np.transpose(np.asarray([r['awake_audio'] for r in records]), (0, 1, 3, 2))
    N2N3_audio = np.transpose(np.asarray([r['N2N3_audio'] for r in records]), (0, 2, 1))
    subject_id = np.asarray([r['subjectid'] for r in records])
    arrays = (awake_image, awake_audio, N2N3_audio, category, subject_id)

    n_total = len(records)
    train_end = int(n_total * ratio)
    val_end = train_end + int(n_total * ((1 - ratio) / 2))

    def _make_split(start, stop):
        # Slice every array with the same bounds; shuffle over the split's own
        # length (at least 1, since tf.data rejects a zero shuffle buffer).
        split = tf.data.Dataset.from_tensor_slices(tuple(a[start:stop] for a in arrays))
        return split.shuffle(max(stop - start, 1)).batch(
            batch_size, drop_remainder=False).prefetch(2)

    train_Dataset = _make_split(0, train_end)
    val_Dataset = _make_split(train_end, val_end)
    test_Dataset = _make_split(val_end, n_total)
    return train_Dataset, val_Dataset, test_Dataset

def load_test_data(path, batch_size):
    """Load sleep-only records from a pickle into a shuffled, batched ``tf.data`` pipeline.

    Args:
        path: pickle file holding an indexable sequence of dicts with keys
            'category', 'N2N3_audio' and 'subjectid'.
        batch_size: samples per batch; the final batch may be smaller.

    Returns:
        A dataset yielding ``(N2N3_audio, category, subject_id)`` batches. The
        last two axes of each ``N2N3_audio`` sample are swapped relative to
        the stored arrays.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as file:  # closes the handle, which a bare open() leaked
        dataset = pickle.load(file)
    records = [dataset[i] for i in range(len(dataset))]

    category = np.asarray([r['category'] for r in records])
    N2N3_audio = np.transpose(np.asarray([r['N2N3_audio'] for r in records]), (0, 2, 1))
    subject_id = np.asarray([r['subjectid'] for r in records])

    dataset_ = tf.data.Dataset.from_tensor_slices((N2N3_audio, category, subject_id))
    # A buffer as large as the dataset gives a full shuffle, then batch.
    return dataset_.shuffle(len(category)).batch(
        batch_size, drop_remainder=False).prefetch(2)

if __name__ == "__main__":
    sys.path.insert(0, os.path.join(os.getcwd(), os.pardir))
    batch_size = 64; seq_len = 100; n_channels = 55; n_features = 768
    shuffle_buffer_size = batch_size ; cycle_length = 1
    ensemble_models = 15
    state_list = ['REM', 'N2N3']
    state = state_list[1]
    ratio_list = [0.2, 0.4, 0.6, 0.8]
    basepath = os.path.join(os.getcwd(), os.pardir)
    recordpath = os.path.join(basepath, 'TFRecord')
    checkpath = os.path.join(basepath, 'checkpoint')
    # Initialize training process.
    utils.model.set_seeds(1642)

    log_format = "%(asctime)s - %(message)s"
    log_file_path = os.path.join(basepath, 'logger/EEG_Finetune_Awake_Sleep_transformer.log')
    logging.basicConfig(filename=log_file_path, format=log_format, level=logging.INFO)
    for ratio in ratio_list:
        individual_models = []
        for i in range(ensemble_models):
            model_item = naive_transformer()
            checkpoint_path = os.path.join(checkpath, "checkpoint_EEG_pretrain_awake_sleep_bj_single/origin_model{:02d}.ckpt").format(i)
            model_item.load_weights(checkpoint_path)
            individual_models.append(model_item)

        path_record = os.path.join(recordpath, "TFRecord_EEG_SZ")
        path_result = os.path.join(basepath, 'EEG_Result/')

        path_whole = [os.path.join(path_record,path_i) for path_i in os.listdir(path_record)\
                        if path_i.startswith("whole_data") ]; path_whole.sort()
        path_whole = path_whole[0]
        train_dataset, validation_dataset, test_dataset = load_whole_data(path_whole, batch_size, ratio)



        optimizer = K.optimizers.Adam(learning_rate=1e-5)

        # training config
        epochs = 100
        train_loss = []
        val_loss = []
        test_loss = []
        train_accuracy = []
        val_accuracy = []
        test_accuracy = []
        train_accuracy2 = []
        val_accuracy2 = []
        test_accuracy2 = []
        train_accuracy3 = []
        val_accuracy3 = []
        test_accuracy3 = []
        no_improvement_epoches = 0
        tf.config.run_functions_eagerly(False)

        for epoch in range(epochs):
            print("\nStart of epoch %d" % (epoch,))
            start_time = time.time()
            
            # Train
            train_losses = 0.0
            Accuracy_train = []
            Accuracy_train2 = []
            Accuracy_train3 = []
            
            for step, input_data in enumerate(train_dataset):
                Loss_value = [] ; Result_Matrix1 = []
                Result_Matrix_2 = []
                raw_index = np.arange(0, 150)
                new_index = np.random.choice(raw_index, ensemble_models, replace=False)
                awake_image = tf.gather(input_data[0], new_index, axis = 1)
                awake_audio = tf.gather(input_data[1], new_index, axis = 1)
                assert awake_audio.shape[1:] == [ensemble_models, 80, 55]
                for _ in range(ensemble_models):
                    model = individual_models[_]
                    image_data = tf.squeeze(awake_image[:,_,:,:])
                    audio_data = tf.squeeze(awake_audio[:,_,:,:])
                    loss_value, Result_Matrix, Result_Matrix2= train_step( \
                        model, image_data ,audio_data, input_data[2], input_data[3], input_data[4], optimizer)
                    Loss_value.append(loss_value) ; Result_Matrix1.append(Result_Matrix)
                    Result_Matrix_2.append(Result_Matrix2)
                loss_value = tf.reduce_mean(Loss_value, axis = 0)
                Result_Matrix = tf.reduce_mean(Result_Matrix1, axis = 0)
                Result_Matrix2 = tf.reduce_mean(Result_Matrix_2, axis = 0)
                Result_Matrix3 = tf.nn.softmax(Result_Matrix2 + Result_Matrix1, axis = -1)
                Result_Matrix = Result_Matrix.numpy() ; Result_Matrix2 = Result_Matrix2.numpy()
                accuracy_train_1 = np.argmax(Result_Matrix, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_train_1 = np.sum(accuracy_train_1) / accuracy_train_1.size
                Accuracy_train.append(accuracy_train_1)
                accuracy_train_2 = np.argmax(Result_Matrix2, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_train_2 = np.sum(accuracy_train_2) / accuracy_train_2.size
                Accuracy_train2.append(accuracy_train_2)
                accuracy_train_3 = np.argmax(Result_Matrix3, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_train_3 = np.sum(accuracy_train_3) / accuracy_train_3.size
                Accuracy_train3.append(accuracy_train_3)
                train_losses += loss_value
            train_losses = train_losses / len(Accuracy_train)
            Accuracy_train = sum(Accuracy_train) / len(Accuracy_train)
            Accuracy_train2 = sum(Accuracy_train2) / len(Accuracy_train2)
            Accuracy_train3= sum(Accuracy_train3) / len(Accuracy_train3)
            
            msg = 'Training Step Finished.'
            print(msg) ; logging.info(msg)

            val_losses = 0.0
        
            Accuracy_val = []
            Accuracy_val2 = []
            Accuracy_val3 = []

            for step, input_data in enumerate(validation_dataset): 
                Loss_value = [] ; Result_Matrix1 = []
                Result_Matrix_2 = []
                raw_index = np.arange(0, 150)
                new_index = np.random.choice(raw_index, ensemble_models, replace=False)
                awake_image = tf.gather(input_data[0], new_index, axis = 1)
                awake_audio = tf.gather(input_data[1], new_index, axis = 1)
                for _ in range(ensemble_models):
                    model = individual_models[_]
                    image_data = tf.squeeze(awake_image[:,_,:,:])
                    audio_data = tf.squeeze(awake_audio[:,_,:,:])
                    loss_value, Result_Matrix,Result_Matrix2 = validation_step( \
                        model, image_data ,audio_data, input_data[2], input_data[3], input_data[4])
                    Loss_value.append(loss_value) ; Result_Matrix1.append(Result_Matrix)
                    Result_Matrix_2.append(Result_Matrix2)
                loss_value = tf.reduce_mean(Loss_value, axis = 0)
                Result_Matrix = tf.reduce_mean(Result_Matrix1, axis = 0)
                Result_Matrix2 = tf.reduce_mean(Result_Matrix_2, axis = 0)
                Result_Matrix3 = tf.nn.softmax(Result_Matrix2 + Result_Matrix1, axis = -1)
                Result_Matrix = Result_Matrix.numpy() ; Result_Matrix2 = Result_Matrix2.numpy()
                accuracy_val_1 = np.argmax(Result_Matrix, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_val_1 = np.sum(accuracy_val_1) / accuracy_val_1.size
                Accuracy_val.append(accuracy_val_1)
                accuracy_val_2 = np.argmax(Result_Matrix2, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_val_2 = np.sum(accuracy_val_2) / accuracy_val_2.size
                Accuracy_val2.append(accuracy_val_2)
                accuracy_val_3 = np.argmax(Result_Matrix3, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_val_3 = np.sum(accuracy_val_3) / accuracy_val_3.size
                Accuracy_val3.append(accuracy_val_3)
                val_losses += loss_value
            val_losses = val_losses / len(Accuracy_val)
            Accuracy_val2 = sum(Accuracy_val2) / len(Accuracy_val2)
            Accuracy_val3 = sum(Accuracy_val3) / len(Accuracy_val3)
            Accuracy_val = sum(Accuracy_val) / len(Accuracy_val)
            msg = 'Validation Step Finished.'
            print(msg) ; logging.info(msg)

            test_losses = 0.0
            Top_accuracy, Top_10_accuracy = [], []
            Accuracy_test= []
            Accuracy_test2= []
            Accuracy_test3= []
            
            for step, input_data in enumerate(test_dataset): 
                Loss_value = [] ; Result_Matrix1 = [] 
                Result_Matrix_2 = []
                raw_index = np.arange(0, 150)
                new_index = np.random.choice(raw_index, ensemble_models, replace=False)
                awake_image = tf.gather(input_data[0], new_index, axis = 1)
                awake_audio = tf.gather(input_data[1], new_index, axis = 1)
                for _ in range(ensemble_models):
                    model = individual_models[_]
                    image_data = tf.squeeze(awake_image[:,_,:,:])
                    audio_data = tf.squeeze(awake_audio[:,_,:,:])
                    loss_value, Result_Matrix,Result_Matrix2 = test_step( \
                        model, image_data ,audio_data, input_data[2], input_data[3], input_data[4])
                    Loss_value.append(loss_value) ; Result_Matrix1.append(Result_Matrix)
                    Result_Matrix_2.append(Result_Matrix2)
                loss_value = tf.reduce_mean(Loss_value, axis = 0)
                Result_Matrix = tf.reduce_mean(Result_Matrix1, axis = 0)
                Result_Matrix2 = tf.reduce_mean(Result_Matrix_2, axis = 0)
                Result_Matrix3 = tf.nn.softmax(Result_Matrix2 + Result_Matrix1, axis = -1)
                Result_Matrix = Result_Matrix.numpy() ; Result_Matrix2 = Result_Matrix2.numpy()
                accuracy_test_1 = np.argmax(Result_Matrix, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_test_1 = np.sum(accuracy_test_1) / accuracy_test_1.size
                Accuracy_test.append(accuracy_test_1)
                accuracy_test_2 = np.argmax(Result_Matrix2, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_test_2 = np.sum(accuracy_test_2) / accuracy_test_2.size
                Accuracy_test2.append(accuracy_test_2)
                accuracy_test_3 = np.argmax(Result_Matrix3, axis=-1) == np.argmax(input_data[3], axis=-1)
                accuracy_test_3 = np.sum(accuracy_test_3) / accuracy_test_3.size
                Accuracy_test3.append(accuracy_test_3)
                test_losses += loss_value

            test_losses = test_losses / len(Accuracy_test)
            Accuracy_test = sum(Accuracy_test) / len(Accuracy_test)
            Accuracy_test2 = sum(Accuracy_test2) / len(Accuracy_test2)
            Accuracy_test3 = sum(Accuracy_test3) / len(Accuracy_test3)

            msg = 'Test Step Finished.'
            print(msg) ; logging.info(msg)
            msg = "Training Loss over epoch: %.4f" % (float(train_losses),)
            print(msg) ; logging.info(msg)
            msg = "Training Accuracy over epoch: %.2f" % (float(Accuracy_train*100),)
            print(msg) ; logging.info(msg)
            msg = "Training Accuracy2 over epoch: %.2f" % (float(Accuracy_train2*100),)
            print(msg) ; logging.info(msg)
            msg = "Training Accuracy3 over epoch: %.2f" % (float(Accuracy_train3*100),)
            print(msg) ; logging.info(msg)

            msg = "Validation Loss: %.4f" % (float(val_losses),)
            print(msg) ; logging.info(msg)
            msg = "Validation Accuracy over epoch: %.2f" % (float(Accuracy_val*100),)
            print(msg) ; logging.info(msg)
            msg = "Validation Accuracy2 over epoch: %.2f" % (float(Accuracy_val2*100),)
            print(msg) ; logging.info(msg)
            msg = "Validation Accuracy3 over epoch: %.2f" % (float(Accuracy_val3*100),)
            print(msg) ; logging.info(msg)

            msg = "Test Loss: %.4f" % (float(test_losses),)
            print(msg) ; logging.info(msg)
            msg = "Test Accuracy over epoch: %.2f" % (float(Accuracy_test*100),)
            print(msg) ; logging.info(msg)
            msg = "Test Accuracy2 over epoch: %.2f" % (float(Accuracy_test2*100),)
            print(msg) ; logging.info(msg)
            msg = "Test Accuracy3 over epoch: %.2f" % (float(Accuracy_test3*100),)
            print(msg) ; logging.info(msg)

            msg = "Time taken: %.2fs" % (time.time() - start_time)
            print(msg) ; logging.info(msg)


            
            val_loss.append(float(val_losses))
            val_accuracy.append(float(Accuracy_val*100))
            val_accuracy2.append(float(Accuracy_val2*100))
            val_accuracy3.append(float(Accuracy_val3*100))
            test_loss.append(float(test_losses))
            test_accuracy.append(float(Accuracy_test*100))
            test_accuracy2.append(float(Accuracy_test2*100))
            test_accuracy3.append(float(Accuracy_test3*100))
        time_now = time.strftime("%Y%m%d-%H%M", time.localtime())
        val_accuracy = np.round(np.array(val_accuracy, dtype=np.float32), decimals=4)
        test_accuracy = np.round(np.array(test_accuracy, dtype=np.float32), decimals=4)
        epoch_maxacc_idxs = np.where(val_accuracy == np.max(val_accuracy))[0]
        epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(test_accuracy[epoch_maxacc_idxs])]
        # Finish training process of current specified experiment.
        msg = (
            "Best test-accuracy ({:.2f}%) according to max validation-accuracy ({:.2f}%) at epoch {:d}."
        ).format(test_accuracy[epoch_maxacc_idx],
            val_accuracy[epoch_maxacc_idx], epoch_maxacc_idx)
        print(msg); logging.info(msg)

        val_accuracy2 = np.round(np.array(val_accuracy2, dtype=np.float32), decimals=4)
        test_accuracy2 = np.round(np.array(test_accuracy2, dtype=np.float32), decimals=4)
        epoch_maxacc_idxs = np.where(val_accuracy2 == np.max(val_accuracy2))[0]
        epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(test_accuracy[epoch_maxacc_idxs])]
        # Finish training process of current specified experiment.
        msg = (
            "Best test-accuracy2 ({:.2f}%) according to max validation-accuracy2 ({:.2f}%) at epoch {:d}."
        ).format(test_accuracy2[epoch_maxacc_idx],
            val_accuracy2[epoch_maxacc_idx], epoch_maxacc_idx)
        print(msg); logging.info(msg)

        val_accuracy3 = np.round(np.array(val_accuracy3, dtype=np.float32), decimals=4)
        test_accuracy3 = np.round(np.array(test_accuracy3, dtype=np.float32), decimals=4)
        epoch_maxacc_idxs = np.where(val_accuracy3 == np.max(val_accuracy3))[0]
        epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(test_accuracy[epoch_maxacc_idxs])]
        # Finish training process of current specified experiment.
        msg = (
            "Best test-accuracy3 ({:.2f}%) according to max validation-accuracy3 ({:.2f}%) at epoch {:d}."
        ).format(test_accuracy3[epoch_maxacc_idx],
            val_accuracy3[epoch_maxacc_idx], epoch_maxacc_idx)
        print(msg); logging.info(msg)