import scipy.io as scio
import numpy as np
import os
import mne
import copy
import json
import pickle
from pathlib import Path
import h5py

# Directory holding the raw recordings; every *.gdf file directly inside it
# is processed by the script entry point.
input_path = r"D:\dataset\bcic_iv_2b"
# Processed FIF files and the combined HDF5 file are written here.
output_path = os.path.join(input_path, "processed_data")
# Base name (without the ".hdf5" suffix) of the combined output file.
h5_output_name = "bcic_iv_2b_combined"
# Cue events that get epoched during the HDF5 conversion: label -> integer code.
event_id = {
    "Cue onset left (class 1)": 10,
    "Cue onset right (class 2)": 11,
}

# Per-subject metadata keyed by the three-character subject prefix of the
# recording file name. Ages are kept as strings.
subject_dict = {
    subject: {'sex': sex, 'age': age}
    for subject, sex, age in (
        ('B01', 'female', '20'),
        ('B02', 'male', '24'),
        ('B03', 'male', '21'),
        ('B04', 'female', '23'),
        ('B05', 'male', '24'),
        ('B06', 'female', '21'),
        ('B07', 'male', '27'),
        ('B08', 'female', '32'),
        ('B09', 'male', '26'),
    )
}


def get_data_row(row):
    """Read the GDF recording referenced by ``row['File Path']``.

    Args:
        row: Mapping-like object with a ``'File Path'`` entry.

    Returns:
        The object returned by ``mne.io.read_raw_gdf`` with ``preload=True``.
    """
    gdf_path = row['File Path']
    loaded_raw = mne.io.read_raw_gdf(gdf_path, preload=True)
    return loaded_raw


def normalize_mne(raw):
    """Z-score each row of ``raw``'s data and store the result back on ``raw``.

    Every row of ``raw.get_data()`` (one channel per row for an MNE Raw
    object) is centred on its own mean and divided by its own standard
    deviation, both computed along axis 1.

    A constant row has a standard deviation of zero; dividing by it would fill
    the row with NaN. Such rows are only centred, so they come out as zeros.

    Args:
        raw: Object exposing ``get_data()`` that returns a 2-D array. Its
            ``_data`` attribute is overwritten with the normalised array.

    Returns:
        The same ``raw`` object, modified in place.
    """
    data = raw.get_data()
    mean = data.mean(axis=1, keepdims=True)
    std = data.std(axis=1, keepdims=True)
    # Guard against flat rows: substitute 1 for a zero std so the division is
    # well defined and the centred (all-zero) values are kept.
    safe_std = np.where(std == 0, 1.0, std)
    raw._data = (data - mean) / safe_std
    return raw


def update_events(original_events, original_event_id, new_event_id):
    """Return a copy of ``original_events`` with event codes remapped.

    A translation table from old integer codes to new integer codes is built
    by matching the keys of ``original_event_id`` against the keys of
    ``new_event_id``:

    * an exact key match always wins;
    * otherwise every new key that contains the original key as a substring
      is a candidate, and the last candidate in ``new_event_id`` order is
      used.

    Giving exact matches precedence prevents a longer, unrelated key (for
    example ``'Eye blinks'``) from overriding the identically named key
    (``'Eye'``) merely because it appears later in ``new_event_id``.

    Args:
        original_events: Sequence of events whose third element (index 2) is
            the integer event code, e.g. an ``(n_events, 3)`` array.
        original_event_id: Mapping of event label to original integer code.
        new_event_id: Mapping of event label to new integer code.

    Returns:
        A deep copy of ``original_events`` in which each code that has a
        translation is replaced; codes without a translation are unchanged.
        The input is not modified.
    """
    original_to_new_mapping = {}
    for original_key, original_value in original_event_id.items():
        if original_key in new_event_id:
            # Exact label match: no need to consider substring candidates.
            original_to_new_mapping[original_value] = new_event_id[original_key]
            continue
        for new_key, new_value in new_event_id.items():
            if original_key in new_key:
                original_to_new_mapping[original_value] = new_value

    new_events = copy.deepcopy(original_events)
    # Each event is translated exactly once, so a new code that happens to
    # equal some other old code is never translated a second time.
    for event in new_events:
        if event[2] in original_to_new_mapping:
            event[2] = original_to_new_mapping[event[2]]
    return new_events


def preprocess_and_save_fif(file_path: str, output_dir: str):
    """Preprocess one GDF recording and save it as ``*_processed_raw.fif``.

    Steps, in order:

    1. Read the GDF file and store a JSON description (original description
       plus subject age/sex looked up in ``subject_dict`` by the first three
       characters of the file name) in ``info['description']``.
    2. Attach the ``standard_1020`` montage (channels it does not know are
       ignored) and mark the three ``EOG:chNN`` channels as EOG.
    3. Remap the event codes to the fixed numbering in ``new_event_id``.
    4. For evaluation files (base name ending in ``E``) replace the
       ``Cue unknown`` events with the true classes read from the sibling
       ``.mat`` file (``classlabel`` 1/2 becomes code 10/11).
    5. Replace the annotations with ones carrying the descriptive labels.
    6. Band-pass filter, notch filter, resample to 200 Hz and z-score.
    7. Save the result into ``output_dir``, overwriting an existing file.

    Args:
        file_path: Path of the ``.gdf`` file to process.
        output_dir: Existing directory that receives the FIF file.

    Raises:
        KeyError: If the subject prefix of the file name is not a key of
            ``subject_dict``.
    """
    mne_raw = mne.io.read_raw_gdf(file_path, preload=True)
    subject_id = os.path.basename(file_path)[:3]

    description_dict = {
        "original_description": mne_raw.info['description'],
        "eegunity_description": {
            "amplifier": "unknown",
            "cap": "Ag/AgCl",
            "age": subject_dict[subject_id]['age'],
            "sex": subject_dict[subject_id]['sex'],
        }
    }
    mne_raw.info['description'] = json.dumps(description_dict)

    montage = mne.channels.make_standard_montage('standard_1020')
    mne_raw.info.set_montage(montage, on_missing='ignore')
    mne_raw.set_channel_types({'EOG:ch01': 'eog', 'EOG:ch02': 'eog', 'EOG:ch03': 'eog'})

    new_event_id = {
        'Rejected trial': 1,
        'Horizontal eye movement': 2,
        'Vertical eye movement': 3,
        'Eye rotation': 4,
        'Eye blinks': 5,
        'Idling EEG (eyes open)': 6,
        'Idling EEG (eyes closed)': 7,
        'Start of a new run': 8,
        'Start of a trial': 9,
        'Cue onset left (class 1)': 10,
        'Cue onset right (class 2)': 11,
        'BCI feedback (continuous)': 12,
        'Cue unknown': 13,
    }

    # MNE exposes GDF events as annotations whose description is the numeric
    # event-type code as a string (e.g. '769'). These never match the
    # descriptive keys of ``new_event_id``, so they are translated to labels
    # first (codes per the BCI Competition IV 2b data-set description).
    gdf_code_labels = {
        '1023': 'Rejected trial',
        '1077': 'Horizontal eye movement',
        '1078': 'Vertical eye movement',
        '1079': 'Eye rotation',
        '1081': 'Eye blinks',
        '276': 'Idling EEG (eyes open)',
        '277': 'Idling EEG (eyes closed)',
        '32766': 'Start of a new run',
        '768': 'Start of a trial',
        '769': 'Cue onset left (class 1)',
        '770': 'Cue onset right (class 2)',
        '781': 'BCI feedback (continuous)',
        '783': 'Cue unknown',
    }

    original_events, original_event_id = mne.events_from_annotations(mne_raw)
    # Without this translation the events would keep the ids MNE assigns by
    # sorted description, which only coincide with ``new_event_id`` when a
    # file contains every code. Descriptions not in the table pass through
    # unchanged.
    labelled_event_id = {
        gdf_code_labels.get(description, description): code
        for description, code in original_event_id.items()
    }
    new_events = update_events(original_events, labelled_event_id, new_event_id)

    file_base, file_ext = os.path.splitext(file_path)
    if file_base.endswith('E') and file_ext == '.gdf':
        unknown_id = new_event_id['Cue unknown']
        mat_filepath = f"{file_base}.mat"
        if os.path.exists(mat_filepath):
            mat_data = scio.loadmat(mat_filepath)
            # classlabel 1 / 2  ->  event code 10 / 11
            values_from_mat = mat_data['classlabel'].flatten() + 9
            replacement_indices = np.where(new_events[:, 2] == unknown_id)[0]
            if len(replacement_indices) == len(values_from_mat):
                new_events[replacement_indices, 2] = values_from_mat
            else:
                # Leave the cues as 'Cue unknown' rather than guess an alignment.
                print(
                    f"Warning: {len(replacement_indices)} unknown cues but "
                    f"{len(values_from_mat)} labels in {mat_filepath}; "
                    f"true labels not applied"
                )
        else:
            print(f"Warning: label file {mat_filepath} not found; true labels not applied")

    sfreq = mne_raw.info['sfreq']

    event_desc = {v: k for k, v in new_event_id.items()}
    annotations = mne.annotations_from_events(
        events=new_events,
        sfreq=sfreq,
        event_desc=event_desc
    )
    mne_raw.set_annotations(annotations)

    # Keep the low-pass edge and the notch just below Nyquist for recordings
    # with a low sampling rate.
    h_freq = min((sfreq / 2) - 1, 75)
    notch_freq = 50 if sfreq > 100 else (sfreq / 2) - 1

    mne_raw = mne_raw.filter(l_freq=0.1, h_freq=h_freq, n_jobs=4)
    mne_raw = mne_raw.notch_filter(freqs=notch_freq, n_jobs=4)
    mne_raw = mne_raw.resample(200, n_jobs=4)
    mne_raw = normalize_mne(mne_raw)

    filename = os.path.basename(file_path).replace('.gdf', '_processed_raw.fif')
    mne_raw.save(os.path.join(output_dir, filename), overwrite=True)


class h5Dataset:
    """Small convenience wrapper around one HDF5 file opened in append mode.

    The file ``<path>/<name>.hdf5`` is opened with mode ``'a'`` on
    construction and stays open until :meth:`save` is called. The object can
    also be used as a context manager, which closes the file on exit even if
    the body raises::

        with h5Dataset(Path(out_dir), "combined") as ds:
            grp = ds.addGroup("recording")
    """

    def __init__(self, path: Path, name: str) -> None:
        """Open (or create) ``path / f'{name}.hdf5'`` for appending."""
        self.__name = name
        self.__f = h5py.File(path / f'{name}.hdf5', 'a')

    def __enter__(self):
        """Return the dataset itself so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the file; exceptions from the ``with`` body propagate."""
        self.save()

    def addGroup(self, grpName: str):
        """Create the group ``grpName`` at the root of the file and return it."""
        return self.__f.create_group(grpName)

    def addDataset(self, grp, dsName, arr, chunks=None, **kwargs):
        """Create dataset ``dsName`` inside ``grp`` holding ``arr``.

        ``chunks`` and any extra keyword arguments are forwarded to
        ``grp.create_dataset``. Returns the created dataset.
        """
        return grp.create_dataset(dsName, data=arr, chunks=chunks, **kwargs)

    def addAttributes(self, src, attrName, attrValue):
        """Set attribute ``attrName`` to ``attrValue`` on ``src.attrs``."""
        src.attrs[attrName] = attrValue

    def save(self):
        """Close the underlying file."""
        self.__f.close()


def convert_all_fif_to_h5(fif_dir, output_dir, h5_name):
    """Epoch every processed FIF file in ``fif_dir`` into one HDF5 file.

    For each ``*_processed_raw.fif`` file (visited in sorted order) a group
    named after the file is created in ``<output_dir>/<h5_name>.hdf5``. It
    holds:

    * ``info``: the pickled ``raw.info`` as a ``uint8`` array. Pickle data
      must only ever be loaded back from files you trust.
    * one dataset per cue label of the module-level ``event_id`` found in
      the file, containing the 0-4 s epochs, with attributes ``rsFreq``
      (sampling rate) and ``chOrder`` (channel names).

    Events are extracted with the explicit ``event_id`` mapping. Letting MNE
    assign codes automatically would number the annotation descriptions in
    sorted order, so codes 10 and 11 would not denote the cue events.

    Files without cue events, and files whose group cannot be created (for
    example because it already exists from an earlier run), are reported and
    skipped. The HDF5 file is closed even if processing fails.

    Args:
        fif_dir: Directory containing the processed FIF files.
        output_dir: Directory of the HDF5 file.
        h5_name: Base name of the HDF5 file (without ``.hdf5``).

    Returns:
        Dict mapping each cue label to the total number of epochs saved.
    """
    dataset = h5Dataset(Path(output_dir), name=h5_name)
    event_info = {}

    try:
        for file in sorted(os.listdir(fif_dir)):
            if not file.endswith('_processed_raw.fif'):
                continue

            file_path = os.path.join(fif_dir, file)
            raw_data = mne.io.read_raw_fif(file_path, preload=True)

            try:
                # Select the cue annotations by description and give them the
                # codes declared in ``event_id``.
                events, _ = mne.events_from_annotations(raw_data, event_id=event_id)
            except ValueError:
                # MNE raises when none of the requested descriptions occur.
                print(f" No matching events in {file}")
                continue

            event_ids_in_data = set(events[:, 2])
            # Only request classes that actually occur; mne.Epochs raises for
            # event ids that have no events.
            present_event_id = {
                label: code for label, code in event_id.items()
                if code in event_ids_in_data
            }
            if not present_event_id:
                print(f" No matching events in {file}")
                continue

            epochs = mne.Epochs(
                raw_data, events=events, event_id=present_event_id,
                tmin=0.0, tmax=4.0, baseline=None, preload=True
            )

            file_name = file.replace('_processed_raw.fif', '')
            try:
                grp = dataset.addGroup(grpName=file_name)
            except ValueError as e:
                # The file is opened in append mode, so a rerun meets groups
                # written earlier; keep them instead of aborting the run.
                print(f"Skipping {file_name}: cannot create group ({e})")
                continue

            info_bytes = pickle.dumps(raw_data.info)
            info_array = np.frombuffer(info_bytes, dtype='uint8')
            dataset.addDataset(grp, 'info', info_array)

            for event in present_event_id:
                try:
                    event_epochs = epochs[event]
                    if len(event_epochs) == 0:
                        # All epochs were dropped (e.g. too close to the end
                        # of the recording); an empty chunk shape is invalid.
                        print(f"No usable epochs for event {event} in {file_name}")
                        continue

                    epoch_data = event_epochs.get_data()
                    if epoch_data.ndim != 3:
                        raise ValueError("Epoch data is not three-dimensional.")

                    dset = dataset.addDataset(grp, event, epoch_data, chunks=epoch_data.shape)
                    dataset.addAttributes(dset, 'rsFreq', raw_data.info['sfreq'])
                    dataset.addAttributes(dset, 'chOrder', event_epochs.info['ch_names'])

                    event_info[event] = event_info.get(event, 0) + len(event_epochs)
                    print(f"Saved event {event} in {file_name}")
                except Exception as e:
                    # Best effort: one failing event must not abort the batch.
                    print(f"Error processing event {event} in {file_name}: {e}")
    finally:
        dataset.save()

    print(" All FIF files have been converted to one HDF5.")
    return event_info


if __name__ == '__main__':
    # Stage 1: turn every GDF recording found directly in input_path into a
    # processed FIF file inside output_path.
    os.makedirs(output_path, exist_ok=True)
    gdf_names = [entry for entry in os.listdir(input_path) if entry.endswith(".gdf")]
    for gdf_name in gdf_names:
        gdf_path = os.path.join(input_path, gdf_name)
        preprocess_and_save_fif(gdf_path, output_path)

    # Stage 2: gather the processed FIF files into a single HDF5 file that is
    # written next to them.
    convert_all_fif_to_h5(output_path, output_path, h5_output_name)
