# Pre-Processing
import glob
import os
import numpy as np
from scipy.io import loadmat
from sklearn.metrics import accuracy_score
from scipy.signal import butter, filtfilt
import ml_collections
from scipy.linalg import sqrtm


# Onset delay in samples per dataset: benchmark 35, eldbeta 35, beta 33
def segment_eeg(dataset_config, fs=250, duration=1.0, order=4, onset=35):
    """Load each subject recording, build three sub-bands and cut one window per trial.

    Args:
        dataset_config: Mutable config object. Reads ``dataset`` ("eldbeta",
            "benchmark" or "beta"), ``data_folder`` (directory with one
            ``.mat`` file per subject, named like ``S12.mat``), ``channels``
            (3, 9, 21, 32 or 64) and ``fs`` (sampling rate handed to the
            band-pass filters). Writes ``num_subjects``, ``channels``,
            ``targets``, ``blocks`` and ``samples``.
        fs: Sampling rate used to turn ``duration`` into a sample count.
        duration: Window length in seconds.
        order: Kept for backward compatibility; not used by this function.
        onset: Delay in samples added after the 125-sample rest period.
            Forced to 33 when ``dataset_config.dataset == "beta"``.

    Returns:
        Tuple ``(X, Y)`` of float32 arrays. ``X`` is laid out as
        (subjects, samples, channels, sub-bands, blocks * targets). ``Y``
        holds the 1-based target label of every trial, shaped
        (subjects, blocks * targets) with size-1 axes squeezed away. Both are
        empty arrays when ``data_folder`` contains no ``.mat`` file.

    Raises:
        ValueError: If ``channels`` has no channel selection, the selection
            length differs from ``channels``, the dataset name is unknown, or
            a recording does not have the expected shape.
    """
    if dataset_config.dataset == "beta":
        onset = 33

    # Sort numerically on the digits after the leading letter so that "S10"
    # follows "S9". os.path.basename keeps this independent of the separator.
    eeg_files = sorted(
        glob.glob(f"{dataset_config.data_folder}/*.mat"),
        key=lambda path: int(os.path.basename(path).split('.')[0][1:]),
    )
    dataset_config.num_subjects = len(eeg_files)

    start = 125 + onset  # 0.5s rest (125 points) + 0.14s delay (35 points)
    length = int(duration * fs)

    channel_table = {
        3: [60, 61, 62],
        9: [47, 53, 54, 55, 56, 57, 60, 61, 62],
        21: [47, 46, 48, 45, 49, 44, 50, 43, 51, 55, 54, 56, 53, 57, 52, 58,
             61, 60, 62, 59, 63],
        # NOTE(review): this selection lists only 30 indices, so requesting
        # 32 channels is rejected by the length check below. The intended
        # extra indices are not known here - confirm and complete the list.
        32: [37, 36, 38, 35, 39, 34, 40, 33, 41, 47, 46, 48, 45, 49, 44, 50,
             43, 51, 55, 54, 56, 53, 57, 52, 58, 61, 60, 62, 59, 63],
        64: list(range(64)),
    }
    if dataset_config.channels not in channel_table:
        raise ValueError(
            f"unsupported channel count {dataset_config.channels!r}; "
            f"expected one of {sorted(channel_table)}"
        )
    permutation = channel_table[dataset_config.channels]
    if dataset_config.channels != len(permutation):
        raise ValueError("channels must be equal to permutation length")

    X, Y = [], []  # per-subject data and labels
    for record in eeg_files:
        print(f"Processing {record}")
        raw = _load_subject_eeg(loadmat(record), dataset_config.dataset)
        raw = raw[:, permutation, :, :]

        # Filter bank: every sub-band is computed from the same unfiltered
        # signal. Re-using the previous band's output as input would apply
        # filters 1..i cumulatively to band i.
        sub_bands = [
            filter_band(i_band, raw, fs=dataset_config.fs)
            for i_band in range(1, 4)
        ]
        # samples, channels, trials, targets, multiband
        eeg = np.stack(sub_bands, axis=-1)
        # samples, channels, multiband, trials, targets
        eeg = eeg.transpose((0, 1, 4, 2, 3))

        # Keep only the analysis window of every trial.
        eeg = eeg[start : start + length, :, :, :, :]
        samples, channels, band_num, blocks, targets = eeg.shape
        dataset_config.channels = channels
        dataset_config.targets = targets
        dataset_config.blocks = blocks
        dataset_config.samples = samples

        # Column-major flattening makes the block index vary fastest, which
        # matches the order used for the data reshape below.
        y = np.tile(np.arange(1, targets + 1), (blocks, 1))
        y = y.reshape((1, blocks * targets), order="F")

        X.append(
            eeg.reshape((samples, channels, band_num, blocks * targets), order="F")
        )
        Y.append(y)

    X = np.array(X, dtype=np.float32, order="F")
    Y = np.array(Y, dtype=np.float32).squeeze()
    return X, Y


def _load_subject_eeg(mat, dataset):
    """Return one subject's EEG as (samples, channels, trials, targets)."""
    if dataset == "eldbeta":
        eeg = mat["data"][0][0][0][0][0][0].transpose((1, 0, 3, 2))
        expected = (1500, 64, 7, 9)
    elif dataset == "benchmark":
        eeg = mat["data"].transpose((1, 0, 3, 2))
        expected = (1500, 64, 6, 40)
    elif dataset == "beta":
        # Only the first 750 samples of each trial are kept.
        eeg = mat["data"][0][0][0].transpose((1, 0, 2, 3))[:750, :, :, :]
        expected = (750, 64, 4, 40)
    else:
        raise ValueError(
            f"unknown dataset {dataset!r}; "
            "expected 'eldbeta', 'benchmark' or 'beta'"
        )
    if eeg.shape != expected:
        raise ValueError(
            f"{dataset} recording has shape {eeg.shape}, expected {expected}"
        )
    return eeg


def filter_band(i_band, data, fs):
    """Band-pass ``data`` into sub-band number ``i_band`` of the filter bank.

    The pass band runs from ``8 * i_band`` Hz up to 90 Hz, so sub-bands
    1, 2 and 3 start at 8, 16 and 24 Hz. The filtering itself is delegated to
    ``filter_eeg`` with its default order and axis.

    Args:
        i_band: Sub-band index (callers in this file use 1, 2 and 3).
        data: Array to filter.
        fs: Sampling rate in Hz.

    Returns:
        The filtered array as returned by ``filter_eeg``.
    """
    low_cut = 8 * i_band
    return filter_eeg(data, fs=fs, band=[low_cut, 90])


def filter_eeg(data, fs=250, band=(5.0, 45.0), order=4, axis=0):
    """Apply a forward-backward Butterworth band-pass filter along one axis.

    Args:
        data: Array to filter.
        fs: Sampling rate in Hz.
        band: ``(low, high)`` cut-off frequencies in Hz. Any two-element
            sequence is accepted; the default is a tuple so that no mutable
            object is shared between calls.
        order: Order passed to ``scipy.signal.butter``.
        axis: Axis of ``data`` along which to filter.

    Returns:
        Array of the same shape as ``data`` holding the filtered signal.
    """
    nyquist = fs / 2
    # butter expects the cut-offs as fractions of the Nyquist frequency.
    normalized_band = np.array(band) / nyquist
    B, A = butter(order, normalized_band, btype="bandpass")
    # filtfilt runs the filter forwards and backwards.
    return filtfilt(B, A, data, axis=axis)


def preprocess_data(dataset_config, duration=0.2):
    """Segment every recording and report the resulting array shapes.

    Thin wrapper around ``segment_eeg`` that takes the sampling rate from
    ``dataset_config.fs`` and prints the shapes of the returned arrays.

    Args:
        dataset_config: Config object forwarded to ``segment_eeg``.
        duration: Window length in seconds.

    Returns:
        The ``(X, Y)`` tuple produced by ``segment_eeg``; ``X`` is laid out
        as subject x samples x channels x sub-bands x trials.
    """
    segments, labels = segment_eeg(
        dataset_config,
        fs=dataset_config.fs,
        duration=duration,
        order=4,
    )
    print(f"X shape: {segments.shape}")
    print(f"Y shape: {labels.shape}")
    return segments, labels


def form_dataset_config(dataset_name, config):
    """Fill ``config`` in place with the settings of the named dataset.

    ``dataset`` and ``multi_band`` are always written. For "benchmark" and
    "beta" the function also writes ``data_folder`` (an empty placeholder
    string), ``target``, ``trials_train`` and ``fs``; "beta" additionally
    fixes ``trials`` to 4. Any other name leaves the remaining fields
    untouched.

    Args:
        dataset_name: Name stored in ``config.dataset``.
        config: Attribute-style config object. ``test_trials`` must already
            be set for "benchmark" and "beta", and ``trials`` for
            "benchmark".

    Returns:
        The same ``config`` object that was passed in.
    """
    config.dataset = dataset_name
    config.multi_band = 3

    selected = config.dataset
    if selected in ("benchmark", "beta"):
        config.data_folder = ""
        if selected == "beta":
            config.trials = 4
        # NOTE(review): segment_eeg writes ``targets`` (plural) while this
        # field is ``target`` - confirm which one downstream code reads.
        config.target = 40
        config.trials_train = config.trials - config.test_trials
        config.fs = 250

    return config


def add_cov_preprocess(x, config):
    """Whiten axis 2 of ``x`` with covariances estimated on the training trials.

    ``x`` must be 6-D with trials on the last axis. The first
    ``config.trials_train`` trials are treated as training data, the rest as
    test data. For every index pair along axis 3 and axis 0 the training
    trials are averaged, flattened to (axis-2 size, -1) and their covariance
    is inverted and square-rooted. The resulting matrix is multiplied onto
    axis 2 of both the training and the test trials of that pair.

    Axis 2 is presumably the channel axis and axes 0 / 3 presumably subjects
    and sub-bands - TODO confirm against the caller.

    Args:
        x: 6-D array, trials last.
        config: Object providing ``trials_train``.

    Returns:
        Array with the same layout as ``x``: transformed training trials
        followed by transformed test trials along the last axis.
    """
    split = config.trials_train
    train_part = x[:, :, :, :, :, :split]
    test_part = x[:, :, :, :, :, split:]

    # Trial-averaged training data, reordered to
    # (axis 3, axis 0, axis 2, axis 1, axis 4).
    mean_train = np.mean(train_part, axis=-1).transpose(3, 0, 2, 1, 4)

    # whiteners[i][j]: matrix square root of the inverse covariance for
    # index i along axis 3 and index j along axis 0.
    whiteners = []
    for per_outer in mean_train:
        row = []
        for block in per_outer:
            flattened = block.reshape(block.shape[0], -1)
            inverse_cov = np.linalg.inv(np.cov(flattened, rowvar=True))
            row.append(sqrtm(inverse_cov))
        whiteners.append(row)

    def _whiten(source, outer, inner, matrix):
        # Contract axis 2 with the matrix, then restore the two indexed axes
        # as singleton dimensions so the pieces can be concatenated.
        projected = np.einsum('bcde,cf->bfde', source[inner, :, :, outer, :, :], matrix)
        return np.expand_dims(projected, axis=(0, 3))

    train_blocks, test_blocks = [], []
    for outer, row in enumerate(whiteners):
        train_blocks.append(
            np.concatenate(
                [_whiten(train_part, outer, inner, m) for inner, m in enumerate(row)],
                axis=0,
            )
        )
        test_blocks.append(
            np.concatenate(
                [_whiten(test_part, outer, inner, m) for inner, m in enumerate(row)],
                axis=0,
            )
        )

    whitened_train = np.concatenate(train_blocks, axis=3)
    whitened_test = np.concatenate(test_blocks, axis=3)
    return np.concatenate([whitened_train, whitened_test], axis=5)


if __name__ == "__main__":
    import argparse

    # The parsed namespace doubles as the dataset config, so it has to carry
    # every field that form_dataset_config and segment_eeg read from it
    # (trials, test_trials, channels); without them the script stops with an
    # AttributeError. The defaults below are placeholders - adjust as needed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="benchmark")
    parser.add_argument("--data_folder", type=str, default="",
                        help="directory containing the per-subject .mat files")
    parser.add_argument("--channels", type=int, default=9,
                        help="number of EEG channels to keep (3, 9, 21, 32 or 64)")
    parser.add_argument("--trials", type=int, default=6,
                        help="trials per target (overridden to 4 for 'beta')")
    parser.add_argument("--test_trials", type=int, default=1,
                        help="trials per target held out for testing")
    parser.add_argument("--fs", type=int, default=250,
                        help="sampling rate in Hz for datasets that do not set it")
    args = parser.parse_args()

    # form_dataset_config resets data_folder for the known datasets, so keep
    # the command-line value and restore it afterwards.
    data_folder = args.data_folder
    dataset_config = form_dataset_config(args.dataset, args)
    if data_folder:
        dataset_config.data_folder = data_folder
    preprocess_data(dataset_config, duration=0.2)
