#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
georgia_ecg_to_h5.py  ──  Georgia 12‑lead ECG Challenge Database preprocessing

This utility converts raw ECG recordings into an HDF5 dataset of fixed‑length,
z‑scored windows, intended for **self‑supervised/contrastive pre‑training**.
"""

import argparse
import os
import sys

import h5py
import numpy as np
import scipy.io as sio
import wfdb
from scipy.signal import butter, filtfilt, iirnotch, sosfiltfilt


# ───────────── Filters ─────────────

def butter_bandpass(
    low: float,
    high: float,
    fs: float,
    order: int = 4,
    output: str = "ba",
):
    """Return Butterworth band-pass filter coefficients.

    Args:
        low: Lower cutoff frequency in Hz.
        high: Upper cutoff frequency in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth prototype order (the band-pass has 2*order poles).
        output: ``"ba"`` (default) returns the ``(b, a)`` transfer-function
            tuple, as before. ``"sos"`` returns second-order sections. SOS is
            the numerically safer form when a cutoff is very small relative
            to ``fs`` (e.g. 0.05 Hz at 500 Hz).

    Returns:
        ``(b, a)`` when ``output="ba"``, or an ``(order, 6)`` SOS array when
        ``output="sos"``.

    Raises:
        ValueError: If the cutoffs do not satisfy ``0 < low < high < fs / 2``.
    """
    nyq = 0.5 * fs
    # butter() would also reject these, but with a less descriptive message.
    if not 0 < low < high < nyq:
        raise ValueError(
            f"need 0 < low < high < fs/2 (= {nyq}), got low={low}, high={high}"
        )
    return butter(order, [low / nyq, high / nyq], btype="band", output=output)


def apply_filters(
    x: np.ndarray,
    fs: int,
    low: float = 0.05,
    high: float = 100,
    notch: float = 60,
    order: int = 4,
    q: float = 30,
):
    """Band-pass + notch filter for 12-lead ECG (zero-phase).

    Args:
        x: Signal array filtered along axis 1. The caller passes
            ``(leads, samples)``.
        fs: Sampling rate in Hz.
        low: Band-pass lower cutoff in Hz.
        high: Band-pass upper cutoff in Hz.
        notch: Notch (mains) frequency in Hz.
        order: Butterworth prototype order for the band-pass.
        q: Quality factor of the notch.

    Returns:
        Filtered array with the same shape as ``x``.

    Raises:
        ValueError: From SciPy, e.g. for cutoffs outside ``(0, fs/2)`` or for
            a signal too short for the filtfilt padding.
    """
    # The band-pass is designed and applied as second-order sections. With the
    # default 0.05 Hz corner at 500 Hz the normalised cutoff is ~2e-4, and the
    # equivalent (b, a) polynomial of this 8th-order filter is numerically
    # ill-conditioned (SciPy recommends SOS for such designs).
    sos_band = butter(order, [low, high], btype="band", fs=fs, output="sos")
    x = sosfiltfilt(sos_band, x, axis=1)
    # A 2nd-order notch is well-conditioned in (b, a) form.
    b_n, a_n = iirnotch(notch, q, fs)
    return filtfilt(b_n, a_n, x, axis=1)


# ───────────── Sliding window + Z‑score ─────────────

def sliding_windows(ecg: np.ndarray, win: int = 1024, step: int = 512):
    """Slice a multichannel signal into overlapping fixed-length windows.

    Windows start at ``0, step, 2*step, ...``. Slicing stops at the first
    window that reaches the end of the signal; that window is zero-padded on
    the right if it runs past the end. Later windows are not emitted because
    their real samples would already be contained in it.

    Args:
        ecg: Array of shape ``(channels, samples)``.
        win: Window length in samples (must be > 0).
        step: Hop between window starts in samples (must be > 0).

    Returns:
        Array of shape ``(n_windows, channels, win)`` with ``ecg.dtype``.
        ``n_windows`` is 0 for an empty signal.

    Raises:
        ValueError: If ``win`` or ``step`` is not positive (a non-positive
            step would otherwise loop forever).
    """
    if win <= 0 or step <= 0:
        raise ValueError(f"win and step must be positive, got win={win}, step={step}")
    n_ch, n_samp = ecg.shape
    segs = []
    for s in range(0, n_samp, step):
        # Buffer sized from the input, so any channel count works.
        w = np.zeros((n_ch, win), dtype=ecg.dtype)
        chunk = ecg[:, s : s + win]
        w[:, : chunk.shape[1]] = chunk  # remainder stays zero-padded
        segs.append(w)
        if s + win >= n_samp:
            break  # this window already reaches the end of the signal
    if not segs:
        return np.zeros((0, n_ch, win), dtype=ecg.dtype)
    return np.stack(segs)


def zscore(w: np.ndarray, eps: float = 1e-8):
    """Channel-wise z-normalisation along the last (time) axis.

    Works on a batch of windows ``(N, C, T)`` (as produced by
    ``sliding_windows``) or on a single window ``(C, T)``.

    Args:
        w: Input array; statistics are computed over the last axis.
        eps: Added to the standard deviation so flat channels (e.g. all-zero
            padding) do not divide by zero.

    Returns:
        Array of the same shape with per-channel zero mean and ~unit std.
    """
    m = w.mean(axis=-1, keepdims=True)
    sd = w.std(axis=-1, keepdims=True) + eps
    return (w - m) / sd


# ───────────── Main ─────────────

def _read_record(rec: str, default_fs: int):
    """Load one record (path without extension); return ``(signal, fs)``.

    Tries ``wfdb.rdsamp`` (header + .mat) first and transposes its output to
    ``(leads, samples)``. If that fails, it reads the raw ``val`` matrix from
    the .mat file and assumes ``default_fs``. A warning is printed because the
    sampling rate is then an assumption.
    """
    try:
        sig, fld = wfdb.rdsamp(rec)  # .mat + .hea
        return sig.T, int(fld["fs"])  # (leads, N)
    except Exception as err:  # wfdb raises a variety of types
        # Some records have no .hea; fall back to .mat only.
        print(
            f"  ! {os.path.basename(rec)}: wfdb read failed ({err}); "
            f"using raw .mat at {default_fs} Hz",
            file=sys.stderr,
        )
    mdict = sio.loadmat(rec + ".mat", squeeze_me=True)
    sig = mdict["val"]  # typical field name (12, N)
    if sig.shape[0] != 12:  # (N, 12) -> (12, N)
        sig = sig.T
    return sig, default_fs


def convert_georgia_to_h5(
    in_dir: str,
    out_path: str,
    win: int = 1024,
    step: int = 512,
    default_fs: int = 500,
):
    """
    Walk through `in_dir`, read *.mat/ *.hea pairs, and write windows to HDF5.
    The resulting file contains a single dataset `data` with shape (N, 12, win),
    ready for pre-training pipelines.

    Each record is read, filtered (``apply_filters``), windowed
    (``sliding_windows``) and z-scored (``zscore``). Records that cannot be
    read or processed (unreadable file, not 12 leads, empty or too short to
    filter) are skipped with a message on stderr instead of aborting the run.
    Directories and files are visited in sorted order, so the dataset layout
    is reproducible.

    Args:
        in_dir: Root directory searched recursively for ``*.mat`` records.
        out_path: Output HDF5 path. An existing file is overwritten.
        win: Window length in samples.
        step: Hop between windows in samples.
        default_fs: Sampling rate assumed for records read without a header
            (Georgia DB: 500 Hz).

    Raises:
        FileNotFoundError: If ``in_dir`` is not a directory. This is checked
            before the output file is created or truncated.
    """
    if not os.path.isdir(in_dir):
        raise FileNotFoundError(f"input directory not found: {in_dir}")
    n_leads = 12

    with h5py.File(out_path, "w") as f:
        dset = f.create_dataset(
            "data",
            shape=(0, n_leads, win),
            maxshape=(None, n_leads, win),
            dtype=np.float32,
            compression="gzip",
        )
        cur = 0
        skipped = 0

        for root, dirs, files in os.walk(in_dir):
            dirs.sort()  # in-place sort fixes os.walk's traversal order
            for file in sorted(files):
                if not file.endswith(".mat"):
                    continue
                rec = os.path.join(root, file[:-4])  # strip extension
                name = os.path.basename(rec)

                # Per-record boundary: one bad file must not kill the batch.
                try:
                    sig, fs = _read_record(rec, default_fs)
                    if sig.ndim != 2 or sig.shape[0] != n_leads:
                        raise ValueError(
                            f"expected {n_leads} leads, got shape {sig.shape}"
                        )
                    if sig.shape[1] == 0:
                        raise ValueError("empty record")
                    # Non-finite samples would poison the IIR filters.
                    ecg = np.nan_to_num(sig, nan=0.0, posinf=0.0, neginf=0.0)
                    ecg = apply_filters(ecg, fs)
                    win_arr = zscore(sliding_windows(ecg, win, step)).astype(np.float32)
                except Exception as err:
                    skipped += 1
                    print(f"✗ {name:10s}  skipped: {err}", file=sys.stderr)
                    continue

                n = win_arr.shape[0]
                dset.resize(cur + n, axis=0)
                dset[cur : cur + n] = win_arr
                cur += n
                print(f"✓ {os.path.basename(rec):10s}  +{n:4d} windows")

        print(f"\n✅ Done. Final dataset shape = {dset.shape}")
        if skipped:
            print(f"   {skipped} record(s) skipped (see messages above)")


# ───────────── CLI ─────────────
if __name__ == "__main__":
    # os.path.join keeps the default paths portable; the previous r"ECG\Georgia"
    # only worked as a directory path on Windows.
    default_in = os.path.join("ECG", "Georgia")  # directory with *.mat / *.hea
    parser = argparse.ArgumentParser(
        description="Convert Georgia 12-lead ECG records into an HDF5 window dataset."
    )
    parser.add_argument("--in-dir", default=default_in, help="input directory (searched recursively)")
    parser.add_argument(
        "--out", default=os.path.join(default_in, "georgia.h5"), help="output HDF5 file"
    )
    parser.add_argument("--win", type=int, default=1024, help="window length in samples")
    parser.add_argument("--step", type=int, default=512, help="hop between windows in samples")
    args = parser.parse_args()
    convert_georgia_to_h5(args.in_dir, args.out, win=args.win, step=args.step)
