import json
import mne
import os
import pandas as pd

import numpy as np
from pathlib import Path
import h5py
import pickle

# Root folder of the raw dataset; the driver below lists *.edf files directly in it.
input_path = r"D:\dataset\figshare_shudb"
# Destination for *_processed_raw.fif files and the combined HDF5 file.
output_path = os.path.join(input_path, "processed_data")
# Folder searched for the *events.tsv file matching each EDF recording.
events_folder_path = os.path.join(input_path, "events")

# Stem (without the ".hdf5" suffix) of the combined HDF5 output file.
h5_output_name = 'figshare_shudb_combined'

# Created eagerly at import time so later writes into output_path succeed.
os.makedirs(output_path, exist_ok=True)


def get_data_row(row):
    """Load the EDF recording referenced by ``row['File Path']``.

    Args:
        row: Mapping (e.g. a pandas Series) holding the EDF path under the
            ``'File Path'`` key.

    Returns:
        The preloaded MNE raw object, or ``None`` if reading raised any
        exception. The error is printed rather than re-raised, so one bad
        file does not abort a batch run.
    """
    edf_path = row['File Path']
    print(f"Attempting to read file: {edf_path}")

    try:
        recording = mne.io.read_raw_edf(edf_path, preload=True)
    except Exception as err:
        print(f"Error reading file {edf_path}: {err}")
        return None
    else:
        print(f"Successfully loaded file: {edf_path}")
        return recording


def read_tsv(file_path):
    """Load a tab-separated events file into a DataFrame.

    The file may or may not have a header row. If its first field is
    ``onset`` (case-insensitive, surrounding whitespace ignored), the first row
    is used as the header. Otherwise the file is read without a header and its
    columns are named ``onset, duration, trial_type, response_time, sample,
    value``.

    Args:
        file_path: Path to the ``.tsv`` file.

    Returns:
        pandas.DataFrame with the parsed rows.
    """
    # 'utf-8-sig' strips a leading UTF-8 byte-order mark. With plain 'utf-8'
    # the BOM would prefix the first field ('\ufeffonset'), and a genuine
    # header row would be misclassified and parsed as data.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        first_field = f.readline().split('\t', 1)[0].strip()

    if first_field.lower() == 'onset':
        return pd.read_csv(file_path, sep='\t', header=0, encoding='utf-8-sig')

    return pd.read_csv(file_path, sep='\t', header=None, encoding='utf-8-sig',
                       names=['onset', 'duration', 'trial_type', 'response_time', 'sample', 'value'])


def normalize_mne(raw):
    """Z-score each channel of *raw* and return *raw*.

    Every row of ``raw.get_data()`` is shifted to zero mean and scaled to unit
    standard deviation along axis 1. A row with zero standard deviation (a
    constant/flat channel) is only mean-centred. It becomes all zeros instead
    of NaN, which a 0/0 division would otherwise produce.

    Args:
        raw: MNE raw object. Its ``_data`` buffer is replaced, so it presumably
            has to be preloaded (``get_data_row`` loads with ``preload=True``).

    Returns:
        The same *raw* object, for chaining.
    """
    data = raw.get_data()
    mean = data.mean(axis=1, keepdims=True)
    std = data.std(axis=1, keepdims=True)
    # (data - mean) is already 0 for constant rows, so dividing by 1 keeps them 0.
    std = np.where(std == 0, 1.0, std)
    raw._data = (data - mean) / std
    return raw


def _preprocess_raw(raw):
    """Band-pass, notch, resample to 200 Hz and z-score *raw*; return the result."""
    sfreq = raw.info['sfreq']
    nyquist = sfreq / 2

    # Low-pass kept at least 1 Hz below Nyquist and capped at 75 Hz.
    h_freq = min(nyquist - 1, 75)
    raw = raw.filter(l_freq=0.1, h_freq=h_freq)

    # 50 Hz notch (presumably power-line noise), applied only when 50 Hz lies
    # strictly below Nyquist (i.e. sfreq > 100). The previous fallback notched
    # at Nyquist - 1 instead, which suppresses an arbitrary band unrelated to 50 Hz.
    if 50 < nyquist:
        raw = raw.notch_filter(freqs=50)

    raw = raw.resample(200)
    return normalize_mne(raw)


def app_func(row, output_dir):
    """Convert one EDF recording into an annotated, preprocessed FIF file.

    The matching events file is found in ``events_folder_path`` by replacing
    ``eeg.edf`` with ``events.tsv`` in the EDF file name. Its rows are attached
    as MNE annotations. A JSON description is stored in ``info['description']``,
    the signal is preprocessed by ``_preprocess_raw``, and the result is saved as
    ``<edf stem>_processed_raw.fif`` in *output_dir*.

    Args:
        row: Mapping/Series with the EDF path under ``'File Path'``.
        output_dir: Directory the processed FIF file is written to.

    Returns:
        None. Returns early, after printing a message, when the events file is
        missing or the EDF cannot be read.
    """
    edf_filename = os.path.basename(row['File Path'])
    events_filename = edf_filename.replace('eeg.edf', 'events.tsv')
    events_file_path = os.path.join(events_folder_path, events_filename)

    # Check the events file before loading the EDF. get_data_row preloads the
    # full recording, which is wasted work if the file is skipped anyway.
    if not os.path.exists(events_file_path):
        print(f"Events file {events_file_path} not found!")
        return

    mne_raw = get_data_row(row)
    if mne_raw is None:
        return

    events_df = read_tsv(events_file_path)
    sfreq = mne_raw.info['sfreq']

    # NOTE(review): onset/duration are treated as 1-based sample counts and
    # converted to seconds. BIDS events.tsv normally stores seconds, and
    # subtracting 1 from a duration is unusual; confirm against the data.
    start_time = (events_df['onset'].values - 1) / sfreq
    duration = (events_df['duration'].values - 1) / sfreq
    label = events_df['trial_type'].values
    mne_raw.set_annotations(
        mne.Annotations(onset=start_time, duration=duration, description=label)
    )

    # Keep the original description and add fixed acquisition metadata for
    # this dataset; subject demographics are recorded as "unknown".
    description_dict = {
        "original_description": mne_raw.info['description'],
        "eegunity_description": {
            "amplifier": "BrainAmp MR plus",
            "cap": "Ag/AgCl",
            "age": "unknown",
            "sex": "unknown",
            "handedness": "unknown"
        }
    }
    mne_raw.info['description'] = json.dumps(description_dict)

    mne_raw = _preprocess_raw(mne_raw)

    stem = os.path.splitext(edf_filename)[0]
    output_fif_path = os.path.join(output_dir, f"{stem}_processed_raw.fif")
    mne_raw.save(output_fif_path, overwrite=True)
    print(f"Saved processed file: {output_fif_path}")


class h5Dataset:
    """Thin wrapper around the HDF5 file ``<path>/<name>.hdf5`` opened via h5py.

    Call ``save()`` when finished, or use the object as a context manager
    (``with h5Dataset(...) as ds:``). The context-manager form closes the file
    even when an exception escapes the ``with`` body.
    """

    def __init__(self, path: Path, name: str, mode: str = 'a') -> None:
        """Open ``<path>/<name>.hdf5``.

        Args:
            path: Directory holding the file; a ``str`` is accepted as well.
            name: File stem; ``.hdf5`` is appended.
            mode: h5py file mode. Defaults to ``'a'`` (read/write, create if
                missing), so existing contents are kept.
        """
        self.__name = name
        self.__f = h5py.File(Path(path) / f'{name}.hdf5', mode)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.save()
        # Do not suppress exceptions raised inside the with-block.
        return False

    def addGroup(self, grpName: str):
        """Create and return a new top-level group named *grpName*.

        h5py's ``create_group`` fails if the name already exists, e.g. when
        appending to a file produced by an earlier run.
        """
        return self.__f.create_group(grpName)

    def addDataset(self, grp, dsName, arr, chunks=None, **kwargs):
        """Create dataset *dsName* in *grp* from *arr*; extra kwargs go to h5py."""
        return grp.create_dataset(dsName, data=arr, chunks=chunks, **kwargs)

    def addAttributes(self, src, attrName, attrValue):
        """Set attribute *attrName* = *attrValue* on the group/dataset *src*."""
        src.attrs[attrName] = attrValue

    def save(self):
        """Close the underlying HDF5 file, flushing pending writes."""
        self.__f.close()


def convert_all_fif_to_h5(fif_dir, output_dir, h5_name):
    """Bundle every ``*_processed_raw.fif`` in *fif_dir* into one HDF5 file.

    For each FIF file, a group named after the file stem is created in
    ``<output_dir>/<h5_name>.hdf5``. The group contains:

    * ``info``: ``raw.info`` pickled and stored as a uint8 byte array.
    * One dataset per event name, holding that event's epochs
      (``tmin=0.0``, ``tmax=4.0``, no baseline). Each dataset carries the
      attributes ``rsFreq`` (sampling rate) and ``chOrder`` (channel names).

    Files without any annotations are skipped. A failure on a single event
    type is printed and skipped.

    Args:
        fif_dir: Directory searched (non-recursively) for processed FIF files.
        output_dir: Directory that receives the HDF5 file.
        h5_name: HDF5 file stem; ``.hdf5`` is appended.

    Returns:
        dict mapping each event name to the total number of epochs written
        for it across all files.
    """
    dataset = h5Dataset(Path(output_dir), name=h5_name)
    event_info = {}

    try:
        # sorted() makes the group write order reproducible across runs/OSes.
        for file in sorted(os.listdir(fif_dir)):
            if not file.endswith('_processed_raw.fif'):
                continue

            file_path = os.path.join(fif_dir, file)
            file_name = file.replace('_processed_raw.fif', '')
            raw_data = mne.io.read_raw_fif(file_path, preload=True)

            events, event_id = mne.events_from_annotations(raw_data)
            if not event_id:
                # No annotations means there is nothing to epoch for this file.
                print(f"No events found in {file_name}, skipping")
                continue

            # NOTE: MNE treats tmax as inclusive, so each epoch spans
            # 4.0 * sfreq + 1 samples.
            epochs = mne.Epochs(
                raw_data, events=events, event_id=event_id,
                tmin=0.0, tmax=4.0, baseline=None, preload=True
            )

            grp = dataset.addGroup(grpName=file_name)

            # NOTE(security): reading this back requires pickle.loads, which
            # can execute arbitrary code. Only load HDF5 files from trusted sources.
            info_bytes = pickle.dumps(raw_data.info)
            info_array = np.frombuffer(info_bytes, dtype='uint8')
            dataset.addDataset(grp, 'info', info_array)

            for event in event_id:
                try:
                    event_epochs = epochs[event]
                    epoch_data = event_epochs.get_data()

                    if epoch_data.ndim != 3:
                        raise ValueError("Epoch data is not three-dimensional.")

                    dset = dataset.addDataset(grp, event, epoch_data, chunks=epoch_data.shape)
                    dataset.addAttributes(dset, 'rsFreq', raw_data.info['sfreq'])
                    dataset.addAttributes(dset, 'chOrder', event_epochs.info['ch_names'])

                    event_info[event] = event_info.get(event, 0) + len(event_epochs)
                    print(f"Saved event {event} in {file_name}")
                except Exception as e:
                    # Best effort per event type: report the failure and keep going.
                    print(f"Error processing event {event} in {file_name}: {e}")
    finally:
        # Close the HDF5 file even if a file fails part-way, so data written so
        # far is flushed and the file is not left open.
        dataset.save()

    print(" All FIF files have been converted to one HDF5.")
    return event_info


if __name__ == '__main__':
    # Stage 1: turn every EDF found directly in input_path into a processed FIF.
    for edf_name in os.listdir(input_path):
        if not edf_name.endswith('.edf'):
            continue
        edf_row = pd.Series({'File Path': os.path.join(input_path, edf_name)})
        app_func(edf_row, output_path)

    print(" All FIF files processed")

    # Stage 2: bundle all processed FIF files into a single HDF5 file.
    convert_all_fif_to_h5(output_path, output_path, h5_output_name)
