
import re
import h5py
import numpy as np
import pandas as pd
import mne
from joblib import Parallel, delayed
# from transformers import AutoModel, AutoTokenizer

import sys
import os

# Append the parent of the current working directory to the import search
# path. NOTE: this depends on the directory the script is launched from, not
# on the location of this file.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
sys.path.append(parent_dir)

def fix_tuh_channel_names(name: str) -> str:
    """Convert a raw TUH channel label into a standard electrode name.

    The ``"EEG "`` prefix and the reference suffix are removed (``"-REF"``
    when the label contains it, ``"-LE"`` otherwise); afterwards ``"FP"`` is
    rewritten as ``"Fp"`` and every ``"Z"`` as ``"z"``.
    """
    reference_suffix = "-REF" if "-REF" in name else "-LE"
    cleaned = name.replace("EEG ", "")
    cleaned = cleaned.replace(reference_suffix, "")
    cleaned = cleaned.replace("FP", "Fp")
    return cleaned.replace("Z", "z")

def parse_age_and_sex_from_edf_header(file_path):
    """Extract the patient's age and sex from the header of an EDF file.

    Args:
        file_path: Path of the EDF file whose header is inspected.

    Returns:
        tuple: ``(age, sex)``. ``age`` is the integer following ``"Age:"`` or
        ``-1`` when it is not found exactly once; ``sex`` is ``"F"`` or
        ``"M"``, or ``"X"`` when no unique whitespace-delimited match exists.
    """
    header = read_edf_header(file_path)
    # bytes 8 to 88 contain ascii local patient identification
    # see https://www.teuniz.net/edfbrowser/edf%20format%20description.html
    # errors="replace": a stray non-ASCII byte must not abort the whole
    # subject, the fields of interest are plain ASCII anyway.
    patient_id = header[8:88].decode("ascii", errors="replace")

    age = -1
    found_age = re.findall(r"Age:(\d+)", patient_id)
    if len(found_age) == 1:
        age = int(found_age[0])

    sex = "X"
    # [FM] rather than [F|M]: inside a character class "|" is a literal
    # character, so the old pattern also accepted "|" as a sex.
    found_sex = re.findall(r"\s([FM])\s", patient_id)
    if len(found_sex) == 1:
        sex = found_sex[0]

    return age, sex

def read_edf_header(file_path):
    """Return the first 88 bytes of the file at ``file_path``.

    Fewer bytes are returned when the file is shorter than that.
    """
    # Context manager so the handle is released even if the read raises.
    with open(file_path, "rb") as f:
        return f.read(88)

def process_subject(subset, subject, source, target, sfreq, verbose):
    """Preprocess all EDF recordings of one subject and save them as ``.npy``.

    The expected layout is ``source/subset/subject/session/montage/recording``.
    Every recording that ``load_and_preprocess_TUEG_edf`` accepts and that
    holds at least one minute of samples is cast to float16 and written to
    ``target/subset/subject/session/{subject}_{session}_t{time}.npy``.

    Args:
        subset: Name of the subset folder below ``source``.
        subject: Name of the subject folder below the subset.
        source: Root folder of the EDF tree.
        target: Root folder that receives the ``.npy`` files.
        sfreq: Sampling frequency passed on to the preprocessing step.
        verbose: Verbosity passed on to the preprocessing step.

    Returns:
        pandas.DataFrame: One row per saved recording with the columns
        ``subject_id, session, time, age, sex, montage, samples, subset``.
    """
    columns = [
        "subject_id", "session", "time", "age", "sex", "montage", "samples", "subset"
    ]
    # Rows are collected in a list and turned into a frame once; growing a
    # DataFrame with pd.concat per recording copies it every time.
    rows = []

    subject_path = os.path.join(source, subset, subject)

    for session in sorted(os.listdir(subject_path)):
        session_path = os.path.join(subject_path, session)
        montages = sorted(os.listdir(session_path))
        if not montages:
            print("SKIP: Session without montage folder", session_path, flush=True)
            continue
        # Only one montage folder per session is processed; sorting makes the
        # choice deterministic should there ever be more than one.
        montage = montages[0]
        montage_path = os.path.join(session_path, montage)
        target_dir = os.path.join(target, subset, subject, session)

        for recording in sorted(os.listdir(montage_path)):
            if not recording.lower().endswith(".edf"):
                # Anything else cannot be read as an EDF recording.
                continue

            full_path = os.path.join(montage_path, recording)
            age, sex = parse_age_and_sex_from_edf_header(full_path)
            # splitext instead of rstrip(".edf"): rstrip removes any trailing
            # run of the characters ".", "e", "d", "f", not the suffix.
            time = os.path.splitext(recording.split("_t")[-1])[0]

            eeg = load_and_preprocess_TUEG_edf(full_path, sfreq=sfreq, verbose=verbose)
            if eeg is None:
                print("SKIP: No valid EEGs", full_path, flush=True)
                continue

            eeg_data = eeg.get_data().astype(np.float16)
            n_samples = eeg_data.shape[1]
            if n_samples < 60 * sfreq:
                print("SKIP: Recording under 1 minute", n_samples / sfreq, full_path, flush=True)
                continue

            os.makedirs(target_dir, exist_ok=True)
            target_name = f"{subject}_{session}_t{time}.npy"
            np.save(os.path.join(target_dir, target_name), eeg_data)

            rows.append({
                "subject_id": subject,
                "session": session,
                "time": time,
                "age": age,
                "sex": sex,
                "montage": montage,
                "samples": n_samples,
                "subset": subset,
            })

    return pd.DataFrame(rows, columns=columns)

def load_and_preprocess_TUEG_edf(fp, sfreq=100, verbose="critical"):
    """Load one EDF recording and reduce it to 20 bipolar channels.

    Processing order: read with preload, reject recordings longer than
    2.5 hours, filter with ``l_freq=0.1`` / ``h_freq=49.0``, drop the first
    10 seconds and trim the end so the remaining duration is a multiple of
    60 seconds, rename the channels, build the 20 bipolar derivations, scale
    by 1e5, clip to [-80, 80] and finally resample to ``sfreq``.

    Args:
        fp: Path of the EDF file.
        sfreq: Target sampling frequency for the final resampling step.
        verbose: Verbosity handed to the MNE calls.

    Returns:
        The processed MNE raw object, or ``None`` when the recording cannot
        be read, is too long, cannot be cropped, or lacks required channels
        (a ``SKIP:`` line is printed in each of these cases).
    """
    # Bipolar derivations as (anode, cathode); the resulting channel names
    # are "<anode>-<cathode>", which is also what is picked afterwards.
    bipolar_pairs = [
        ("Fp1", "F7"), ("F7", "T3"), ("T3", "T5"), ("T5", "O1"),
        ("Fp2", "F8"), ("F8", "T4"), ("T4", "T6"), ("T6", "O2"),
        ("T3", "C3"), ("C3", "Cz"), ("Cz", "C4"), ("C4", "T4"),
        ("Fp1", "F3"), ("F3", "C3"), ("C3", "P3"), ("P3", "O1"),
        ("Fp2", "F4"), ("F4", "C4"), ("C4", "P4"), ("P4", "O2"),
    ]

    try:
        f = mne.io.read_raw_edf(fp, preload=True, verbose=verbose)
    except Exception as err:
        # One unreadable file should be skipped, not abort the whole subject.
        print("SKIP: Error when reading. ", fp, repr(err), flush=True)
        return None
    total_duration = f.times[-1]

    if total_duration > (150*60): # longer than 2.5 hours
        print("SKIP: Individual recording longer than 2.5 hours", total_duration, fp, flush=True)
        return None

    # filter
    f = f.filter(l_freq=0.1, h_freq=49.0, verbose=verbose)

    # crop: remove first 10 seconds and ensure duration is multiple of 60 seconds
    new_endpoint = total_duration - (total_duration - 10) % 60
    try:
        f = f.crop(tmin=10, tmax=new_endpoint, verbose=verbose)
    except Exception as err:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        print("SKIP: Error when cropping. Total_duration ", total_duration, fp, repr(err), flush=True)
        return None

    # rename the channels
    f.rename_channels(fix_tuh_channel_names, verbose=verbose)

    # create bipolar channels and leave out the others
    try:
        f = mne.set_bipolar_reference(
            f,
            anode=[anode for anode, _ in bipolar_pairs],
            cathode=[cathode for _, cathode in bipolar_pairs],
            verbose=verbose)
        f.pick([f"{anode}-{cathode}" for anode, cathode in bipolar_pairs], verbose=verbose)
    except Exception as err:
        print("SKIP: Error regarding channels. ", fp, repr(err), flush=True)
        return None

    # rescale
    def scale_data(x):
        # Important: uV would be *1e6
        # However, this gives stdev ~= 15
        # For FP16 training, we may enjoy greater precision by reducing the stdev to ~1.5
        return x * 1e5
    f.apply_function(fun=scale_data, picks="all", channel_wise=False)

    # clip (applied to the scaled values, before resampling)
    def clip_data(x):
        return np.clip(x, -80, 80)
    f.apply_function(fun=clip_data, picks="all", channel_wise=False)

    # resample to the target sampling frequency
    f.resample(sfreq=sfreq, n_jobs=1, verbose=verbose)

    return f

def preprocess_TUEG(source="/path/to/TUEG/edf/", target="/path/to/TUEG/deriv/",
                    sfreq=100, n_jobs=6, verbose="critical"):
    """Preprocess the whole EDF tree subset by subset.

    For every subset folder below ``source`` the subjects are processed in
    parallel with ``process_subject``; the returned metadata rows are
    appended to ``metadata.csv`` inside ``target`` after each subset.

    Args:
        source: Root folder containing one folder per subset.
        target: Root folder for the derivatives and ``metadata.csv``.
        sfreq: Sampling frequency handed to ``process_subject``.
        n_jobs: Number of parallel workers.
        verbose: Verbosity handed to ``process_subject``.
    """
    os.makedirs(target, exist_ok=True)
    metadata_path = os.path.join(target, "metadata.csv")

    for subset in sorted(os.listdir(source)):
        print("Starting subset ", subset, flush=True)
        subjects = sorted(os.listdir(os.path.join(source, subset)))
        if not subjects:
            # pd.concat raises on an empty list of frames.
            print("SKIP: Subset without subjects", subset, flush=True)
            continue

        # Parallelize subject processing
        results = Parallel(n_jobs=n_jobs)(
            delayed(process_subject)(subset, subject, source, target, sfreq, verbose)
            for subject in subjects
        )
        subset_df = pd.concat(results, ignore_index=True)

        # Append instead of read + concat + rewrite: reading the CSV back
        # lets pandas re-infer dtypes, which turns zero-padded identifiers
        # such as "000" into 0 for all previously stored rows.
        write_header = (not os.path.exists(metadata_path)
                        or os.path.getsize(metadata_path) == 0)
        subset_df.to_csv(metadata_path, mode="a", header=write_header, index=False)
        print(f"Saved intermediate results for subset: {subset}", flush=True)
            
def TUEG_to_h5(source="/path/to/TUEG/deriv/", target=None):
    """Write one HDF5 group per session holding its concatenated recordings.

    All ``.npy`` files of a session are loaded in sorted order and joined
    along axis 1. The result is stored as dataset ``eeg_data`` in the group
    ``{subject}_{session}``, whose attribute ``n_timepoints`` holds the
    length of axis 1.

    Args:
        source: Root folder of the ``.npy`` derivatives
            (``subset/subject/session/*.npy``).
        target: Path of the HDF5 file to create. ``None`` selects the
            historical default path.
    """
    if target is None:
        # File name kept exactly as before (including the trailing
        # underscore) so existing consumers still find the file.
        target = "/path/to/TUEG/data/TUEG_100Hz_TCP_.h5"

    n_channels = 20

    # Only folders are subsets; this skips plain files such as the
    # metadata.csv that preprocessing writes into the same root.
    subject_sets = [
        entry for entry in sorted(os.listdir(source))
        if os.path.isdir(os.path.join(source, entry))
    ]

    with h5py.File(target, "w") as f:

        # iterate over sets
        for subset in subject_sets:
            print(subset)
            subset_path = os.path.join(source, subset)

            for subject in sorted(os.listdir(subset_path)):
                subject_path = os.path.join(subset_path, subject)

                for session in sorted(os.listdir(subject_path)):
                    session_path = os.path.join(subject_path, session)

                    eegs = [
                        np.load(os.path.join(session_path, time))
                        for time in sorted(os.listdir(session_path))
                    ]
                    if not eegs:
                        # np.concatenate needs at least one array.
                        continue
                    eeg_concat = np.concatenate(eegs, axis=1)
                    n_timepoints = eeg_concat.shape[1]
                    if n_timepoints == 0:
                        continue

                    grp = f.create_group(f'{subject}_{session}')
                    # A chunk may not be larger than the dataset itself.
                    grp.create_dataset("eeg_data", data=eeg_concat,
                                       chunks=(n_channels, min(500, n_timepoints)),
                                       compression="gzip", compression_opts=4)

                    grp.attrs["n_timepoints"] = n_timepoints
                    
                    
def TUEG_to_h5_epochs(source="/path/to/TUEG/deriv/", target=None,
                      epoch_length=2000, max_length=270000):
    """Saves the preprocessed EEG as an .h5 file of fixed-length epochs.

    Every ``.npy`` recording is truncated to ``max_length`` samples, cut into
    non-overlapping epochs of ``epoch_length`` samples and appended to the
    resizable dataset ``features`` (epochs x channels x samples, float16).
    Per-epoch bookkeeping is stored next to it:

    - ``long_subject_id``: subject folder name
    - ``unique_subject_ids``: running index of the subject
    - ``session_ids`` / ``time_ids``: session folder name and the last
      ``_``-separated part of the file name without ``.npy``
    - ``subject_ids``: running index of the *recording* (see comment below)
    - ``epochs``: number of epochs of each stored recording

    Args:
        source: Root folder of the ``.npy`` derivatives.
        target: Path of the HDF5 file to create. ``None`` selects the default
            path derived from ``epoch_length``.
        epoch_length: Samples per epoch (default 2000, i.e. 20 sec @ 100 Hz).
        max_length: Maximum samples used per recording (default 270000,
            i.e. 45 minutes @ 100 Hz).
    """
    n_channels = 20

    if target is None:
        suffix = f"_EPOCHS_{int(epoch_length/100)}s"
        target = f"/path/to/TUEG/data/TUEG_timewise_100Hz_TCP{suffix}.h5"

    subject_sets = [s for s in sorted(os.listdir(source)) if not s.endswith('.csv')]

    with h5py.File(target, "w") as f:

        dset = f.create_dataset(
            "features", (0, n_channels, epoch_length),
            maxshape=(None, n_channels, epoch_length),
            chunks=(1, n_channels, epoch_length), dtype='float16',
            compression="gzip", compression_opts=4)

        epochs = []
        long_subject_id = []
        subject_idx = []
        session_id = []
        time_id = []
        sample_idx = []

        sample_count = 0
        unique_subject_count = 0

        # iterate over sets
        for subset in subject_sets:
            print(subset)
            subset_path = os.path.join(source, subset)

            for subject in sorted(os.listdir(subset_path)):
                subject_path = os.path.join(subset_path, subject)

                for session in sorted(os.listdir(subject_path)):
                    session_path = os.path.join(subject_path, session)

                    for time in sorted(os.listdir(session_path)):
                        recording_path = os.path.join(session_path, time)
                        eeg = np.load(recording_path)[:, :max_length]
                        n_epochs = eeg.shape[1] // epoch_length
                        if n_epochs == 0:
                            # Nothing to store; also, a "-0:" slice would
                            # address the whole dataset instead of nothing.
                            print("SKIP: Recording shorter than one epoch", recording_path, flush=True)
                            continue

                        # Drop a trailing partial epoch so the reshape is valid.
                        usable = n_epochs * epoch_length
                        crops = eeg[:, :usable].reshape(n_channels, n_epochs, epoch_length) # C,E,L
                        crops = np.transpose(crops, (1,0,2)) # into E,C,L

                        start = dset.shape[0]
                        dset.resize(start + n_epochs, axis=0)
                        dset[start:start + n_epochs, :, :] = crops

                        long_subject_id.extend([subject] * n_epochs) # aaaaaaaa
                        subject_idx.extend([unique_subject_count] * n_epochs) # 0
                        session_id.extend([session] * n_epochs) # s001
                        time_id.extend([time.split("_")[-1].replace(".npy", "")] * n_epochs) # t000
                        sample_idx.extend([sample_count] * n_epochs) # 0

                        epochs.append(n_epochs)

                        sample_count += 1
                unique_subject_count += 1

        f.create_dataset("dataset_std", data = 1.)
        f.create_dataset("dataset_mean", data = 0.)
        f.create_dataset("long_subject_id", data = long_subject_id)
        # careful! downstream we need subject_id to map to single eeg files hence this unintuitive assignment
        f.create_dataset("subject_ids", data = sample_idx)
        f.create_dataset("session_ids", data = session_id)
        f.create_dataset("time_ids", data = time_id)
        f.create_dataset("unique_subject_ids", data=subject_idx)
        f.create_dataset("epochs", data=epochs)
        # The with-block closes the file.
        
# TUEG_to_h5_epochs()

def generate_mappings(h5_path="/path/to/TUEG/data/TUEG_timewise_100Hz_TCP_EPOCHS_20s.h5",
                      out_dir="/path/to/TUEG/data/"):
    """Function used to create mappings between long TUEG subject identifiers of the form
    aaaaaaa_s001_t001 (time id)
    aaaaaaa_s001 (session id)
    aaaaaaa (subject id)
    to integer indices of corresponding EEG crops in the .h5 dataset.

    Each mapping is a dict ``identifier -> sorted list of the distinct values
    found in the "subject_ids" dataset for that identifier``. The three dicts
    are written with ``np.save`` into ``out_dir``; since they are Python
    objects, loading them back requires ``np.load(..., allow_pickle=True)``.

    Args:
        h5_path: Epoch-wise HDF5 file to read the identifiers from.
        out_dir: Folder receiving the three ``.npy`` mapping files.

    Raises:
        ValueError: If the identifier datasets differ in length.
    """
    import h5py
    from collections import defaultdict

    with h5py.File(h5_path, "r") as file:
        h5_long = [s.decode('utf-8') for s in file["long_subject_id"][:]]
        h5_ses = [s.decode('utf-8') for s in file["session_ids"][:]]
        h5_time = [s.decode('utf-8') for s in file["time_ids"][:]]
        h5_idx = file["subject_ids"][:]

    # All four datasets hold one entry per epoch. Padding a shorter one (as
    # zip_longest did) would silently produce bogus identifiers.
    if not (len(h5_long) == len(h5_ses) == len(h5_time) == len(h5_idx)):
        raise ValueError("identifier datasets in the .h5 file differ in length")

    h5_longses = [a + "_" + b for a, b in zip(h5_long, h5_ses)]
    h5_longsestime = [a + "_" + b for a, b in zip(h5_longses, h5_time)]

    def _group_indices(keys):
        # Single pass over the epochs instead of one full scan per key.
        grouped = defaultdict(set)
        for key, idx in zip(keys, h5_idx):
            grouped[key].add(idx)
        return {key: sorted(grouped[key]) for key in sorted(grouped)}

    subject_mapping = _group_indices(h5_long)
    # The subject mapping used to go to a relative "path/to/..." location,
    # unlike the other two; all three now land in out_dir.
    np.save(os.path.join(out_dir, "20s_subject_id_timewise_mapping.npy"), subject_mapping)

    session_mapping = _group_indices(h5_longses)
    np.save(os.path.join(out_dir, "20s_session_id_timewise_mapping.npy"), session_mapping)

    time_mapping = _group_indices(h5_longsestime)
    np.save(os.path.join(out_dir, "20s_time_id_timewise_mapping.npy"), time_mapping)
    
# generate_mappings()