import os
import ast
import h5py
import numpy as np
import pandas as pd
from tqdm import tqdm
import wfdb

# Import filtering functions from scipy.signal
from scipy.signal import butter, iirnotch, filtfilt

# ------------------------------------------------------
# Global parameters (adjust as needed)
# ------------------------------------------------------
# Sampling rate in Hz of the records that are read (the 'filename_hr' files);
# it is handed to filter_ecg as `fs`.
FS = 500
# Sliding-window length and hop between consecutive window starts, in samples.
WINDOW_SIZE, STEP_SIZE = 1024, 512
# Segments buffered in memory before each HDF5 append; also the HDF5 chunk length.
CHUNK_SIZE = 500

# Only the five major diagnostic superclasses are kept; each is mapped to an
# integer label in this fixed order (NORM=0 ... HYP=4).
SUPERCLASS_MAP = {
    superclass: label
    for label, superclass in enumerate(("NORM", "MI", "STTC", "CD", "HYP"))
}

# ------------------------------------------------------
# 1. Construct SCP code to diagnostic class mapping
# ------------------------------------------------------
def load_diag_map(scp_csv_path):
    """
    Build a {scp_code: diagnostic_class} dictionary from scp_statements.csv.

    The first CSV column is used as the index (the SCP code). Only rows whose
    'diagnostic' column equals 1 are kept; rows where it is empty/NaN or any
    other value are dropped.
    """
    statements = pd.read_csv(scp_csv_path, index_col=0)
    is_diagnostic = statements["diagnostic"] == 1
    return statements.loc[is_diagnostic, "diagnostic_class"].to_dict()

# ------------------------------------------------------
# 2. Parse scp_codes and select the label with highest confidence >= threshold
# ------------------------------------------------------
def parse_scp_codes_single_label_with_threshold(scp_str, diag_map, threshold=100.0):
    """
    Reduce an scp_codes string to a single integer superclass label.

    scp_str: string literal of a {code: likelihood} dict, parsed with
        ast.literal_eval (no arbitrary code execution).
    diag_map: {code: diagnostic_class}; codes absent from it are ignored.
    threshold: minimum likelihood the best code must reach.

    The code with the strictly highest likelihood wins (on ties the first one
    seen is kept). Returns None when no usable code reaches `threshold` or when
    its diagnostic class is not one of SUPERCLASS_MAP's keys; otherwise returns
    SUPERCLASS_MAP[class].
    """
    parsed = ast.literal_eval(scp_str)

    top_score, top_class = -1, None
    for code, score in parsed.items():
        if code in diag_map and score > top_score:
            top_score, top_class = score, diag_map[code]

    # Short-circuit: the superclass lookup only happens once the threshold passes.
    if top_score < threshold or top_class not in SUPERCLASS_MAP:
        return None
    return SUPERCLASS_MAP[top_class]

# ------------------------------------------------------
# 3. Read ECG data
# ------------------------------------------------------
def read_ecg_data(record_path):
    """
    Load a WFDB record (.hea/.dat pair, path given without extension) as a
    channel-first array.

    wfdb.rdsamp yields samples shaped (n_samples, n_channels); the header
    metadata it also returns is discarded. The result is the transposed view,
    shaped (n_channels, n_samples), i.e. (12, n_samples) for 12-lead records.
    """
    samples, _fields = wfdb.rdsamp(record_path)
    return np.transpose(samples)

# ------------------------------------------------------
# 4. ECG filtering function (example: 0.5-40 Hz bandpass + 50 Hz notch)
# ------------------------------------------------------
def filter_ecg(ecg, fs=500, *, lowcut=0.5, highcut=40.0, order=4,
               notch_freq=50.0, q=30.0):
    """
    Apply a zero-phase Butterworth bandpass followed by a zero-phase notch filter.

    ecg: array-like whose last axis is time, e.g. [12, n_samples]. Any leading
        shape is accepted, including a single 1-D lead.
    fs: sampling rate in Hz.
    lowcut, highcut, order: bandpass corner frequencies (Hz) and Butterworth order.
    notch_freq, q: notch centre frequency (Hz) and quality factor. Pass
        notch_freq=None to skip the notch stage.

    Returns: filtered ECG with the same shape. Floating-point (or complex) inputs
        keep their dtype. Integer inputs are returned as float64 instead of being
        truncated back to integers.
    Raises: ValueError if a cutoff or notch frequency is not strictly inside
        (0, fs / 2), or if lowcut >= highcut.
    """
    ecg = np.asarray(ecg)
    nyq = 0.5 * fs

    if not 0 < lowcut < highcut < nyq:
        raise ValueError(
            f'filter_ecg: need 0 < lowcut < highcut < fs/2, got '
            f'lowcut={lowcut}, highcut={highcut}, fs={fs}')
    b_band, a_band = butter(order, [lowcut / nyq, highcut / nyq], btype='band')
    # filtfilt runs forward and backward, so the output has no phase delay.
    # axis=-1 filters every lead in one call instead of a Python loop.
    filtered = filtfilt(b_band, a_band, ecg, axis=-1)

    if notch_freq is not None:
        if not 0 < notch_freq < nyq:
            raise ValueError(
                f'filter_ecg: notch_freq={notch_freq} must lie in (0, fs/2={nyq})')
        # iirnotch's default fs=2.0 expects w0 normalised to Nyquist.
        b_notch, a_notch = iirnotch(notch_freq / nyq, q)
        filtered = filtfilt(b_notch, a_notch, filtered, axis=-1)

    # Keep the caller's floating dtype, but never cast the result to an integer dtype.
    out_dtype = ecg.dtype if np.issubdtype(ecg.dtype, np.inexact) else np.float64
    return filtered.astype(out_dtype, copy=False)

# ------------------------------------------------------
# 5. Sliding window segmentation + normalization
# ------------------------------------------------------
def sliding_window_ecg(ecg, window_size=WINDOW_SIZE, step_size=STEP_SIZE):
    """
    Cut an ECG into overlapping windows and z-score each window per channel.

    ecg: [n_channels, n_samples] (12 leads in this pipeline).
    Windows start at 0, step_size, 2*step_size, ... and only full windows are
    kept, so a signal shorter than window_size yields an empty list.
    Each channel of each window is centred on its own mean and divided by its
    own std. A std below 1e-7 is replaced by 1.0, so flat channels become zeros
    instead of NaN/inf.
    Returns: list of arrays, each [n_channels, window_size].
    """
    def _zscore(window):
        centre = window.mean(axis=1, keepdims=True)
        spread = window.std(axis=1, keepdims=True)
        spread[spread < 1e-7] = 1.0
        return (window - centre) / spread

    total = ecg.shape[1]
    starts = range(0, total - window_size + 1, step_size)
    return [_zscore(ecg[:, begin:begin + window_size]) for begin in starts]

# ------------------------------------------------------
# 6. HDF5 dataset creation and writing (datasets 'input' and 'label')
# ------------------------------------------------------
def create_h5_datasets(f, data_shape, chunk_size):
    """
    Create empty, growable 'input' (float32) and 'label' (int32) datasets in f.

    f: an object exposing create_dataset, e.g. an h5py File opened for writing.
    data_shape: shape template; its first entry is ignored and data_shape[1:]
        is the shape of one sample (e.g. (None, 12, WINDOW_SIZE)).
    chunk_size: rows per HDF5 chunk along the unlimited first axis.
    Returns: (input_dataset, label_dataset), both starting with 0 rows.
    """
    sample_shape = data_shape[1:]
    input_ds = f.create_dataset(
        'input',
        shape=(0,) + sample_shape,
        maxshape=(None,) + sample_shape,
        chunks=(chunk_size,) + sample_shape,
        dtype=np.float32,
    )
    label_ds = f.create_dataset(
        'label',
        shape=(0,),
        maxshape=(None,),
        chunks=(chunk_size,),
        dtype=np.int32,
    )
    return input_ds, label_ds

def append_h5(ds, ls, db, lb):
    """
    Append a batch of segments and their labels to two resizable datasets.

    ds: resizable dataset grown along axis 0 (the 'input' dataset).
    ls: resizable 1-D dataset grown in lockstep with ds (the 'label' dataset).
    db: batch of segments, shape (batch, *ds.shape[1:]).
    lb: labels for db, one per segment.

    Raises: ValueError if db and lb disagree on the batch size. The check runs
        before any resize, so the datasets are never left grown but unfilled.
    """
    n_new = db.shape[0]
    if len(lb) != n_new:
        raise ValueError(
            f'append_h5: got {n_new} segments but {len(lb)} labels')
    if n_new == 0:
        return  # nothing to write; avoid no-op resizes
    cur = ds.shape[0]
    new = cur + n_new
    ds.resize((new,) + ds.shape[1:])
    ls.resize((new,))
    ds[cur:new] = db
    ls[cur:new] = lb

# ------------------------------------------------------
# 7. Core pipeline: 12 channels => filtering => sliding window => writing => verification
# ------------------------------------------------------
def _flush_segments(dset, lset, buf_d, buf_l):
    """Write buffered segments/labels to the HDF5 datasets, then empty the buffers."""
    if not buf_d:
        return
    # Casting here matches the float32 dataset dtype and halves the write buffer.
    db = np.stack(buf_d).astype(np.float32, copy=False)
    lb = np.array(buf_l, dtype=np.int32)
    append_h5(dset, lset, db, lb)
    buf_d.clear()
    buf_l.clear()


def _report_h5(out):
    """Print the keys, dataset shapes and label histogram of a written HDF5 file."""
    with h5py.File(out, 'r') as f:
        print(f'[Check] {out}')
        print('  keys:', list(f.keys()))
        print('  input shape:', f['input'].shape)
        print('  label shape:', f['label'].shape)
        labels = f['label'][:]
        unique, counts = np.unique(labels, return_counts=True)
        print('  label distribution:')
        for u, c in zip(unique, counts):
            print(f'    class {u}: {c}')


def process_split(df, root, out, diag_map, thr=100.0):
    """
    Convert one split of the PTB-XL table into an HDF5 file and print a summary.

    df: rows with at least 'scp_codes' and 'filename_hr' columns.
    root: directory that 'filename_hr' paths are relative to.
    out: path of the HDF5 file to (over)write; it receives 'input'
        (n, 12, WINDOW_SIZE) float32 and 'label' (n,) int32 datasets.
    diag_map: {scp_code: diagnostic_class} as returned by load_diag_map.
    thr: minimum likelihood passed to parse_scp_codes_single_label_with_threshold.

    Records without a usable label, or shorter than WINDOW_SIZE samples, are
    skipped. Each kept record is filtered, cut into normalised windows, and
    every window inherits the record's label. Windows are buffered and written
    once at least CHUNK_SIZE of them have accumulated.
    """
    buf_d, buf_l = [], []
    with h5py.File(out, 'w') as f:
        dset, lset = create_h5_datasets(f, (None, 12, WINDOW_SIZE), CHUNK_SIZE)
        for _, row in tqdm(df.iterrows(), total=len(df), desc=f'Writing {out}'):
            lbl = parse_scp_codes_single_label_with_threshold(row['scp_codes'], diag_map, thr)
            if lbl is None:
                continue
            sig = read_ecg_data(os.path.join(root, row['filename_hr']))
            if sig.shape[1] < WINDOW_SIZE:
                continue
            segs = sliding_window_ecg(filter_ecg(sig, FS))
            buf_d.extend(segs)
            buf_l.extend([lbl] * len(segs))
            if len(buf_d) >= CHUNK_SIZE:
                _flush_segments(dset, lset, buf_d, buf_l)
        # Write whatever is left over (a no-op when the buffer is empty).
        _flush_segments(dset, lset, buf_d, buf_l)
    _report_h5(out)

# ------------------------------------------------------
# 8. Main function
# ------------------------------------------------------
def main(root=None, thr=90.0):
    """
    Build train/val/test HDF5 files from the PTB-XL database.

    root: PTB-XL directory containing ptbxl_database.csv and scp_statements.csv.
        Defaults to ECG/ptbxl relative to the working directory. os.path.join
        is used instead of a hard-coded backslash path so this also works on
        POSIX systems.
    thr: label-likelihood threshold forwarded to process_split.

    Split by 'strat_fold': folds 1-8 -> train, 9 -> val, 10 -> test. Each split
    is written to <root>/<split>.h5.
    """
    if root is None:
        root = os.path.join('ECG', 'ptbxl')
    df = pd.read_csv(os.path.join(root, 'ptbxl_database.csv'))
    scp_map = load_diag_map(os.path.join(root, 'scp_statements.csv'))
    splits = {
        'train': df[df['strat_fold'] <= 8],
        'val':   df[df['strat_fold'] == 9],
        'test':  df[df['strat_fold'] == 10],
    }
    for name, sub in splits.items():
        out_file = os.path.join(root, f'{name}.h5')
        process_split(sub.reset_index(drop=True), root, out_file, scp_map, thr=thr)

if __name__ == '__main__':
    main()