import os,sys
import numpy as np
import pickle
import tensorflow as tf
import time
import copy as cp
from collections import Counter
import tensorflow.keras as K
import matplotlib.pyplot as plt
import logging
# Script-only setup: make the parent of the *current working directory*
# (not of this file) importable, then pull in project-local modules.
# These imports are skipped when this file is imported as a module; the
# function definitions below do not use them, only the script section at the
# bottom of the file does.
if __name__ == "__main__":
    sys.path.insert(0, os.path.join(os.getcwd(), os.pardir))
    from params import model_params  # NOTE(review): not referenced in the visible code
    import utils  # provides utils.model.set_seeds used by the script section
    from utils import DotDict  # NOTE(review): not referenced in the visible code
    from model.EEG_Finetune_Sleep_transformer.naive_transformer import naive_transformer

def _parse_function(example_proto, data_type):
    """Decode one serialized ``tf.train.Example`` into fixed-shape float tensors.

    Every feature is stored as a bytes feature holding a serialized float32
    tensor; each is parsed back with ``tf.io.parse_tensor`` and reshaped to the
    shape this pipeline expects.

    Args:
        example_proto: scalar string tensor containing one serialized Example.
        data_type: currently unused; kept so existing callers
            (``get_dataset_from_tfrecords``) keep working.

    Returns:
        Tuple ``(N2N3_audio, category_name, audio, subject_id,
        image_last_hidden_state, image_pooler_output, N2N3_audio_t)`` with
        shapes ``[80, 55]``, ``[15]``, ``[25, 1024]``, ``[33]``, ``[768]``,
        ``[1, 768]`` and ``[80, 55]`` respectively.

    Raises:
        ValueError: if the pooled image hidden state does not have static shape
            ``[768]`` (an explicit check, unlike ``assert`` which ``-O`` strips).
    """
    feature_keys = (
        'category_name',
        'feature_N2N3_audio',
        'feature_N2N3_audio_t',
        'feature_id',
        'feature_audio',
        'feature_image_last_hidden_state',
        'feature_image_pooler_output',
    )
    feature_spec = {key: tf.io.FixedLenFeature((), tf.string) for key in feature_keys}
    parsed_features = tf.io.parse_single_example(example_proto, feature_spec)
    # Each bytes feature holds a serialized float32 tensor.
    tensors = {
        key: tf.io.parse_tensor(value, tf.float32)
        for key, value in parsed_features.items()
    }

    # Stored as [55, 80]; transposed so the 80-long axis comes first.
    N2N3_audio = tf.transpose(
        tf.reshape(tensors['feature_N2N3_audio'], shape=[55, 80]), perm=[1, 0])
    N2N3_audio_t = tf.transpose(
        tf.reshape(tensors['feature_N2N3_audio_t'], shape=[55, 80]), perm=[1, 0])

    # Named subject_id rather than `id` to avoid shadowing the builtin.
    subject_id = tf.reshape(tensors['feature_id'], shape=[33])
    audio = tf.reshape(tensors['feature_audio'], shape=[25, 1024])
    category_name = tf.reshape(tensors['category_name'], shape=[15])

    # [197, 768] token states: row 0 is dropped (presumably the CLS token —
    # confirm) and the remaining rows are mean-pooled to a single [768] vector.
    last_hidden = tf.reshape(tensors['feature_image_last_hidden_state'], [197, 768])
    image_last_hidden_state = tf.reduce_mean(last_hidden[1:, :], axis=0)
    if image_last_hidden_state.shape.as_list() != [768]:
        raise ValueError(
            "pooled image_last_hidden_state has shape %s, expected [768]"
            % (image_last_hidden_state.shape,))

    image_pooler_output = tf.reshape(tensors['feature_image_pooler_output'], [1, 768])
    return (N2N3_audio, category_name, audio, subject_id,
            image_last_hidden_state, image_pooler_output, N2N3_audio_t)




def get_dataset_from_tfrecords(record_path, batch_size, shuffle_buffer_size, cycle_length, data_type):
    """Build a shuffled, batched dataset from TFRecord files matching a glob.

    Args:
        record_path: file pattern passed to ``tf.data.Dataset.list_files``.
        batch_size: number of parsed examples per batch.
        shuffle_buffer_size: element buffer for ``Dataset.shuffle``.
        cycle_length: number of files read concurrently by ``interleave``.
        data_type: forwarded to ``_parse_function``.

    Returns:
        A ``tf.data.Dataset`` of batched ``_parse_function`` output tuples.
    """
    autotune = tf.data.experimental.AUTOTUNE
    files = tf.data.Dataset.list_files(record_path, shuffle=True)
    dataset = files.interleave(
        map_func=tf.data.TFRecordDataset, cycle_length=cycle_length)
    # Parallel parsing; output order stays deterministic by default.
    dataset = dataset.map(
        lambda record: _parse_function(record, data_type),
        num_parallel_calls=autotune)
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.batch(batch_size=batch_size)
    # prefetch() after batch() counts *batches*: the former prefetch(batch_size)
    # buffered batch_size whole batches. Let tf.data tune the depth instead.
    return dataset.prefetch(autotune)



# @tf.function
def train_step(model_decode_meg, train_data, category_id, train_wav, train_subjectid, optimizer):
    """Run one gradient update of ``model_decode_meg`` on a single batch.

    The model is called on the 4-tuple of inputs in training mode and must
    return ``(loss, result_matrix)``. Gradients of ``loss`` with respect to the
    model's trainable variables are applied through ``optimizer``.

    Returns:
        The ``(loss, result_matrix)`` pair from the forward pass.
    """
    model_inputs = (train_data, category_id, train_wav, train_subjectid)
    with tf.GradientTape() as tape:
        loss, result_matrix = model_decode_meg(model_inputs, training=True)

    # Read the variable list only after the forward pass: a lazily built model
    # has no variables until it has been called once.
    variables = model_decode_meg.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss, result_matrix

# @tf.function
def validation_step(model_decode_meg, validation_data, category_id, validation_wav, validation_subjectid):
    """Evaluate ``model_decode_meg`` on one validation batch in inference mode.

    Returns:
        The ``(loss, result_matrix)`` pair produced by the model.
    """
    batch = (validation_data, category_id, validation_wav, validation_subjectid)
    loss, result_matrix = model_decode_meg(batch, training=False)
    return loss, result_matrix
# @tf.function
def test_step(model_test, test_data , category_id, test_wav,  test_subjectid):

    loss_value, Result_Matrix = \
          model_test((test_data , category_id, test_wav, test_subjectid), training = False)

    return loss_value, Result_Matrix


def load_data(path, batch_size):
    """Load a pickled split and wrap it as a shuffled, batched ``tf.data.Dataset``.

    Args:
        path: path to a pickle holding an indexable sequence of records; each
            record is indexed with the keys ``'category'``, ``'audio'``,
            ``'N2N3_audio'`` and ``'subjectid'``.
        batch_size: samples per batch (the final batch may be smaller).

    Returns:
        A dataset of ``(N2N3_audio, category, audio, subject_id)`` batches.
        ``N2N3_audio`` has its last two axes swapped relative to the pickle.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this project's own preprocessing.
    with open(path, 'rb') as file:  # the original never closed this handle
        dataset = pickle.load(file)

    # Index-based access on purpose: the pickle's container type is not visible
    # here, so do not assume that iterating it yields the records.
    records = [dataset[index] for index in range(len(dataset))]
    category = np.asarray([record['category'] for record in records])
    N2N3_audio = np.asarray([record['N2N3_audio'] for record in records])
    audio = np.asarray([record['audio'] for record in records])
    subject_id = np.asarray([record['subjectid'] for record in records])
    # Swap the last two axes of every sample (requires 3-D N2N3_audio).
    N2N3_audio = np.transpose(N2N3_audio, (0, 2, 1))

    dataset_length = len(category)
    slices = tf.data.Dataset.from_tensor_slices((N2N3_audio, category, audio, subject_id))
    # Buffer spans the whole split, so each pass is a full uniform shuffle.
    return slices.shuffle(dataset_length).batch(
        batch_size, drop_remainder=False).prefetch(2)

def load_train_data(path, batch_size, ratio):
    """Load a leading fraction of the pickled training split as a ``tf.data.Dataset``.

    Args:
        path: path to a pickle holding an indexable sequence of records; each
            record is indexed with the keys ``'category'``, ``'audio'``,
            ``'N2N3_audio'`` and ``'subjectid'``.
        batch_size: samples per batch (the final batch may be smaller).
        ratio: the number of kept records is ``int(len(records) * ratio / 0.8)``.
            Presumably ``ratio`` is a fraction of the *whole* dataset, with the
            training split being 80% of it (so ``ratio=0.8`` keeps every
            training record) — confirm. Values above 0.8 keep everything.

    Returns:
        A dataset of ``(N2N3_audio, category, audio, subject_id)`` batches.
        ``N2N3_audio`` has its last two axes swapped relative to the pickle.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this project's own preprocessing.
    with open(path, 'rb') as file:  # the original never closed this handle
        dataset = pickle.load(file)
    n_samples = int(len(dataset) * (ratio / 0.8))

    # Index-based access on purpose: the pickle's container type is not visible here.
    records = [dataset[index] for index in range(len(dataset))]
    category = np.asarray([record['category'] for record in records])
    N2N3_audio = np.asarray([record['N2N3_audio'] for record in records])
    audio = np.asarray([record['audio'] for record in records])
    subject_id = np.asarray([record['subjectid'] for record in records])
    # Swap the last two axes of every sample (requires 3-D N2N3_audio).
    N2N3_audio = np.transpose(N2N3_audio, (0, 2, 1))

    dataset_length = len(category)
    # The subset is the first n_samples records in stored order (so subsets for
    # increasing ratios are nested); shuffling happens only afterwards.
    slices = tf.data.Dataset.from_tensor_slices((
        N2N3_audio[:n_samples],
        category[:n_samples],
        audio[:n_samples],
        subject_id[:n_samples],
    ))
    return slices.shuffle(dataset_length).batch(
        batch_size, drop_remainder=False).prefetch(2)

def _first_file(directory, prefix):
    """Return the path of the lexicographically first entry in ``directory`` starting with ``prefix``.

    Raises:
        FileNotFoundError: if no entry matches.
    """
    matches = sorted(name for name in os.listdir(directory) if name.startswith(prefix))
    if not matches:
        raise FileNotFoundError(
            "No file starting with %r in %s" % (prefix, directory))
    return os.path.join(directory, matches[0])


def _run_ensemble_epoch(models, dataset, step_fn, optimizers=None):
    """Run every ensemble member over ``dataset`` once and score the averaged output.

    For each batch, every model is run through ``step_fn`` on the batch's first
    four elements. Losses and result matrices are averaged across models, and
    the argmax of the averaged matrix is compared with the argmax of the
    batch's second element (the category labels).

    Args:
        models: ensemble members.
        dataset: iterable of batches whose first four elements are the inputs.
        step_fn: ``train_step`` / ``validation_step`` / ``test_step``.
        optimizers: for training, one optimizer per model, passed to
            ``step_fn`` as its final positional argument.

    Returns:
        ``(mean_loss, accuracy)``. The loss is averaged over batches. Accuracy
        is counted over samples, so a smaller final batch is not over-weighted.
        Both are ``nan`` when ``dataset`` yields no batches.
    """
    loss_sum, n_batches = 0.0, 0
    n_correct, n_seen = 0, 0
    for batch in dataset:
        data, category, wav, subject_id = batch[0], batch[1], batch[2], batch[3]
        losses, matrices = [], []
        for index, model in enumerate(models):
            extra = () if optimizers is None else (optimizers[index],)
            loss, matrix = step_fn(model, data, category, wav, subject_id, *extra)
            losses.append(loss)
            matrices.append(matrix)
        loss_sum += float(tf.reduce_mean(losses, axis=0).numpy())
        ensemble_matrix = tf.reduce_mean(matrices, axis=0).numpy()
        hits = np.argmax(ensemble_matrix, axis=-1) == np.argmax(category, axis=-1)
        n_correct += int(np.sum(hits))
        n_seen += hits.size
        n_batches += 1
    if n_batches == 0:
        return float('nan'), float('nan')
    return loss_sum / n_batches, n_correct / n_seen


if __name__ == "__main__":
    # Fine-tune an ensemble of pretrained models on growing fractions of the
    # training split. For each fraction, report the test accuracy at the epoch
    # with the best validation accuracy.
    # (sys.path was already extended by the import section at the top of the file.)
    batch_size = 256
    shuffle_buffer_size = 256 ; cycle_length = 1  # only used by the TFRecord loader
    ensemble_models = 10
    loss_mode = 'CLIP'  # NOTE(review): not referenced below
    type_list = ['origin', 'simple']
    data_type = type_list[0]  # NOTE(review): not referenced below
    state_list = ['REM', 'N2N3']
    state = state_list[1]  # NOTE(review): not referenced below
    ratio_list = [0.2, 0.4, 0.6, 0.8]
    epochs = 200
    basepath = os.path.join(os.getcwd(), os.pardir)
    recordpath = os.path.join(basepath, 'TFRecord')
    checkpath = os.path.join(basepath, 'checkpoint')
    utils.model.set_seeds(1642)

    log_format = "%(asctime)s - %(message)s"
    log_file_path = os.path.join(basepath, 'logger/EEG_Finetune_Sleep_transformer.log')
    logging.basicConfig(filename=log_file_path, format=log_format, level=logging.INFO)

    # Paths and the validation/test splits do not depend on the ratio: resolve
    # and load them once instead of once per ratio.
    path_record = os.path.join(recordpath, "TFRecord_EEG_SZ", "sleep")
    path_result = os.path.join(basepath, 'EEG_Result/')  # NOTE(review): not referenced below
    path_train = _first_file(path_record, "train_data")
    path_val = _first_file(path_record, "val_data")
    path_test = _first_file(path_record, "test_data")
    validation_dataset = load_data(path_val, batch_size)
    test_dataset = load_data(path_test, batch_size)
    tf.config.run_functions_eagerly(False)

    for ratio in ratio_list:
        logging.info("Fine-tuning with training ratio %.1f", ratio)

        # Restart every ratio from the pretrained checkpoints.
        individual_models = []
        for i in range(ensemble_models):
            model_item = naive_transformer()
            # Format only the file name, so braces in checkpath cannot break .format().
            checkpoint_path = os.path.join(
                checkpath, "checkpoint_EEG_pretrain_sleep_bj",
                "transformer_new_model{:02d}.ckpt".format(i))
            model_item.load_weights(checkpoint_path)
            individual_models.append(model_item)
        # One Adam per member. A shared optimizer advanced its single `iterations`
        # counter once per member per batch, which skewed every member's bias
        # correction. TF >= 2.11 optimizers also reject variables they were not
        # built with.
        optimizers = [K.optimizers.Adam(learning_rate=1e-5) for _ in individual_models]

        train_dataset = load_train_data(path_train, batch_size, ratio)

        train_loss, val_loss, test_loss = [], [], []
        train_accuracy, val_accuracy, test_accuracy = [], [], []
        best_epoch = None
        # Tracked for diagnostics only; no early stopping is applied.
        no_improvement_epoches = 0

        for epoch in range(epochs):
            print("\nStart of epoch %d" % (epoch,))
            start_time = time.time()

            train_losses, Accuracy_train = _run_ensemble_epoch(
                individual_models, train_dataset, train_step, optimizers)
            print('Training Step Finished.')
            if np.isnan(train_losses):
                break

            val_losses, Accuracy_val = _run_ensemble_epoch(
                individual_models, validation_dataset, validation_step)
            print('Validation Step Finished.')

            test_losses, Accuracy_test = _run_ensemble_epoch(
                individual_models, test_dataset, test_step)
            print('Test Step Finished.')

            print("Training Loss over epoch: %.4f" % (float(train_losses),))
            print("Training Accuracy over epoch: %.2f" % (float(Accuracy_train*100),))

            print("Validation Loss: %.4f" % (float(val_losses),))
            print("Validation Accuracy over epoch: %.2f" % (float(Accuracy_val*100),))
            print("Test Loss: %.4f" % (float(test_losses),))
            print("Test Accuracy over epoch: %.2f" % (float(Accuracy_test*100),))

            msg = "Time taken: %.2fs" % (time.time() - start_time)
            print(msg) ; logging.info(msg)

            if not val_loss or val_losses < min(val_loss):
                # First epoch or a new best validation loss. Epoch 0 must be
                # recorded too, otherwise best_epoch stayed unset when epoch 0
                # was the best.
                best_epoch = epoch
                no_improvement_epoches = 0
            elif val_losses >= val_loss[-1]:
                no_improvement_epoches += 1

            train_loss.append(float(train_losses))
            train_accuracy.append(float(Accuracy_train))
            val_loss.append(float(val_losses))
            val_accuracy.append(float(Accuracy_val))
            test_loss.append(float(test_losses))
            test_accuracy.append(float(Accuracy_test))

        if best_epoch is None:
            # No epoch completed (training loss was NaN from the start).
            msg = 'No completed epoch for ratio %.1f; skipping summary.' % (ratio,)
            print(msg) ; logging.info(msg)
            continue

        msg = 'Best Performance Epoch is at %d' % (best_epoch,)
        print(msg) ; logging.info(msg)
        val_accuracy = np.round(np.array(val_accuracy, dtype=np.float32), decimals=4)
        test_accuracy = np.round(np.array(test_accuracy, dtype=np.float32), decimals=4)
        # Among epochs tied at the best validation accuracy, report the one
        # with the highest test accuracy.
        epoch_maxacc_idxs = np.where(val_accuracy == np.max(val_accuracy))[0]
        epoch_maxacc_idx = epoch_maxacc_idxs[np.argmax(test_accuracy[epoch_maxacc_idxs])]
        msg = (
            "Best test-accuracy ({:.2f}%) according to max validation-accuracy ({:.2f}%) at epoch {:d}."
        ).format(test_accuracy[epoch_maxacc_idx]*100.,
            val_accuracy[epoch_maxacc_idx]*100., epoch_maxacc_idx)
        print(msg) ; logging.info(msg)