#!/usr/bin/env python3
"""
MIMIC‑III ICU mortality preprocessing
====================================

This script reproduces the complete data preparation pipeline from the
`Pacmed/aisg_2019` GitHub repository.  The original project consists
of several standalone modules (`patient_characteristics.py`,
`vital_signs.py`, `lab_results.py`, and `concat.py`) that each write
intermediate files to ``./data/interim`` and ultimately generate the
processed feature matrices and labels used to train the deterministic
and Bayesian neural networks.  Combining these disparate steps into a
single Python file simplifies reproducibility: point the script to the
MIMIC‑III v1.4 data directory and it will write the same CSVs expected
by the original training notebooks.

Key behaviours reproduced from the upstream code:

* **Patient demographics** are assembled by merging the ``PATIENTS``,
  ``ADMISSIONS`` and ``ICUSTAYS`` tables.  Age is computed from date
  of birth and hospital admission time (capped at 90 years to respect
  MIMIC’s de‑identification) and added as a feature.  A boolean
  ``EXPIRED_THIS_ICUSTAY`` column flags deaths occurring during the
  last ICU stay of a hospital admission.  Additional static features
  include gender, ICU length of stay (``LOS``) and the number of
  hours between hospital admission and ICU admission.

* **Vital sign aggregation** uses the mapping in
  ``chartevents_numerical_features.tsv`` to extract a small set of
  numeric ITEMIDs from the ``CHARTEVENTS`` table.  For each ICU stay
  we compute the **mean** and **standard deviation** of each vital
  sign.  Columns are named with the aggregation first (e.g.
  ``mean_heartrate`` instead of ``heartrate_mean``) to match the
  original code.  Invasive and cuff blood pressure measurements are
  combined: a new ``combined_bp`` feature uses the arterial value when
  present, otherwise the cuff value.  The non‑invasive columns are
  dropped but the arterial columns remain, just as in the repository.

* **Laboratory test aggregation** follows the same pattern using
  ``labevents_numerical_features.tsv`` to select relevant lab ITEMIDs
  from ``LABEVENTS``.  Lab results are restricted to times within the
  ICU stay (between ``INTIME`` and ``OUTTIME``).  Again we compute
  per‑stay means and standard deviations and name columns with the
  aggregation first.

* **Concatenation, cleaning and scaling** joins the interim tables on
  ``SUBJECT_ID``, ``HADM_ID`` and ``ICUSTAY_ID``.  Gender is
  one‑hot encoded (dropping the first category) and newborn
  admissions (``ADMISSION_TYPE == 'NEWBORN'``) are set aside as an
  out‑of‑domain test set.  The remaining data undergoes outlier
  removal using an 8×IQR rule over all feature columns, then rows
  with implausibly low mean combined blood pressure or non‑positive
  pre‑ICU hospital time are dropped.  Standard deviation columns with
  missing values are filled with zero.  Finally the dataset is split
  into a 90/10 train/test split (no fixed random seed is used in the
  original), median‑imputed and min‑max scaled.  Newborn rows are
  transformed with the same imputer and scaler.  All processed and
  raw matrices as well as labels are saved to the processed output
  directory (``--processed_dir``, ``data2/processed`` by default).

Run this script from the repository root or a directory containing
``data/chartevents_numerical_features.tsv`` and
``data/labevents_numerical_features.tsv``.  Example invocation:

```
python mimic_preprocess_single.py --data_dir /path/to/mimic-iii-clinical-database-1.4
```

With the default ``--interim_dir`` and ``--processed_dir`` settings the
script writes the following files:

```
data2/interim/patient_characteristics.csv
data2/interim/vital_signs.csv
data2/interim/lab_results.csv
data2/processed/X_train_raw.csv
data2/processed/X_test_raw.csv
data2/processed/X_train_processed.csv
data2/processed/X_test_processed.csv
data2/processed/y_train.csv
data2/processed/y_test.csv
data2/processed/X_newborns.csv
data2/processed/y_newborns.csv
```

This file is self‑contained and does not depend on the pacmagic or
dask packages used in the upstream repository.  It should therefore
run in a standard Python environment with ``pandas``, ``numpy``,
``duckdb`` and ``scikit‑learn`` installed.
"""


from __future__ import annotations

import argparse
import pandas as pd
import numpy as np
import duckdb
from pathlib import Path
from typing import List, Tuple

from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split


def find_data_file(data_dir: Path, base_name: str) -> Tuple[str, str]:
    """Find ``<base_name>.csv.gz`` or ``<base_name>.csv`` inside ``data_dir``.

    The gzipped variant is preferred when both exist.  The result is a
    ``(path, compression)`` pair where ``path`` is a string and
    ``compression`` is ``'gzip'`` for the ``.csv.gz`` file or ``''`` for the
    plain ``.csv`` file.

    Raises:
        FileNotFoundError: if neither variant exists in ``data_dir``.
    """
    # Order matters: the compressed file wins when both are present.
    candidates = ((".csv.gz", 'gzip'), (".csv", ''))
    for extension, compression in candidates:
        candidate = data_dir / f"{base_name}{extension}"
        if candidate.exists():
            return str(candidate), compression
    raise FileNotFoundError(f"Could not locate {base_name}.csv(.gz) in {data_dir}")


def read_mapping(tsv_name: str) -> pd.DataFrame:
    """Read a tab-separated mapping file into a DataFrame.

    The file is looked up first in the ``data`` directory next to this
    script; if it is not there, ``data/<tsv_name>`` relative to the current
    working directory is read instead.
    """
    bundled = Path(__file__).resolve().parent / 'data' / tsv_name
    mapping_path = bundled if bundled.exists() else Path('data') / tsv_name
    return pd.read_csv(mapping_path, sep='\t')


# Mean Gregorian year length in seconds, used to express ages in years.
_SECONDS_PER_YEAR = 365.2425 * 24 * 3600


def _find_table(data_dir: Path, table: str) -> Path:
    """Return the first ``<table>.csv*`` file in ``data_dir`` (e.g. .csv or .csv.gz)."""
    matches = sorted(Path(data_dir).glob(f"{table}.csv*"))
    if not matches:
        raise FileNotFoundError(f"Could not locate {table}.csv(.gz) in {data_dir}")
    return matches[0]


def _age_in_years(admit: pd.Series, dob: pd.Series) -> pd.Series:
    """Return ``admit - dob`` in years, computed without nanosecond overflow."""
    # A nanosecond-resolution timedelta only spans about 292 years, so a gap
    # of ~300 years cannot be represented.  Working at second resolution in
    # NumPy keeps the arithmetic well inside the int64 range.
    admit_s = admit.to_numpy().astype("datetime64[s]")
    dob_s = dob.to_numpy().astype("datetime64[s]")
    delta = admit_s - dob_s
    seconds = delta.astype("int64").astype("float64")
    seconds[np.isnat(delta)] = np.nan  # missing dates stay missing
    return pd.Series(seconds / _SECONDS_PER_YEAR, index=admit.index)


def build_patient_characteristics(data_dir: Path, interim_dir: Path) -> pd.DataFrame:
    """Compute patient demographics and static features using pandas.

    Mirrors ``data/patient_characteristics.py`` from the upstream repo.
    ``PATIENTS``, ``ADMISSIONS`` and ``ICUSTAYS`` are outer-merged and the
    following columns are derived:

    * ``EXPIRED_THIS_ICUSTAY`` -- True when the row is the ICU stay with the
      latest ``INTIME`` of its admission and ``HOSPITAL_EXPIRE_FLAG`` is set.
    * ``age`` -- years between date of birth and the patient's earliest
      ``ADMITTIME``.  Values above 89 (and the large negative values that a
      wrapped-around subtraction used to produce) are replaced by 90.
    * ``time_at_hosp_pre_ic_admission`` -- hours between ``ADMITTIME`` and
      ``INTIME``.

    Args:
        data_dir: Directory holding the MIMIC tables as ``.csv`` or
            ``.csv.gz`` files.
        interim_dir: Output directory; created if it does not exist.

    Returns:
        The selected columns as a DataFrame, which is also written to
        ``interim_dir/patient_characteristics.csv``.

    Raises:
        FileNotFoundError: if one of the three tables is missing.
    """
    interim_dir = Path(interim_dir)

    # Read core tables
    patients = pd.read_csv(
        _find_table(data_dir, "PATIENTS"),
        usecols=["SUBJECT_ID", "GENDER", "DOB"],
        parse_dates=["DOB"],
        low_memory=False,
    )
    admissions = pd.read_csv(
        _find_table(data_dir, "ADMISSIONS"),
        usecols=[
            "SUBJECT_ID",
            "HADM_ID",
            "ADMISSION_TYPE",
            "ADMITTIME",
            "HOSPITAL_EXPIRE_FLAG",
            "ADMISSION_LOCATION",
            "DISCHARGE_LOCATION",
            "ETHNICITY",
            "DIAGNOSIS",
        ],
        parse_dates=["ADMITTIME"],
        low_memory=False,
    )
    icu = pd.read_csv(
        _find_table(data_dir, "ICUSTAYS"),
        usecols=["SUBJECT_ID", "HADM_ID", "ICUSTAY_ID", "INTIME", "LOS"],
        parse_dates=["INTIME"],
        low_memory=False,
    )

    # Merge tables
    df = patients.merge(admissions, on="SUBJECT_ID", how="outer")
    df = df.merge(icu, on=["SUBJECT_ID", "HADM_ID"], how="outer")

    # Identify final ICU stay per admission.  Rows without an INTIME compare
    # unequal to the group maximum and are therefore never flagged.
    final_intime = df.groupby("HADM_ID")["INTIME"].transform("max")
    df["FINAL_ICUSTAY"] = df["INTIME"] == final_intime
    df["EXPIRED_THIS_ICUSTAY"] = df["FINAL_ICUSTAY"] & df["HOSPITAL_EXPIRE_FLAG"].fillna(0).astype(bool)

    # Compute age at first admission
    first_adm = admissions.sort_values("ADMITTIME").drop_duplicates("SUBJECT_ID")[["SUBJECT_ID", "ADMITTIME"]]
    age_df = first_adm.merge(patients, on="SUBJECT_ID", how="left")
    age = _age_in_years(age_df["ADMITTIME"], age_df["DOB"])
    # Cap at 90: ``age > 89`` catches the ~300-year ages of de-identified
    # elderly patients, ``age < -200`` keeps the original guard for values
    # produced by a wrapped-around subtraction.
    age = age.mask((age < -200) | (age > 89), 90.0)
    age_df = pd.DataFrame({"SUBJECT_ID": age_df["SUBJECT_ID"], "age": age})
    df = df.merge(age_df, on="SUBJECT_ID", how="left")

    # Time before ICU admission in hours
    df["time_at_hosp_pre_ic_admission"] = (
        (df["INTIME"] - df["ADMITTIME"]).dt.total_seconds() / 3600.0
    )

    # Select columns to keep
    cols = [
        "SUBJECT_ID",
        "HADM_ID",
        "ICUSTAY_ID",
        "EXPIRED_THIS_ICUSTAY",
        "GENDER",
        "age",
        "LOS",
        "time_at_hosp_pre_ic_admission",
        "ADMISSION_TYPE",
        "ADMISSION_LOCATION",
        "DISCHARGE_LOCATION",
        "ETHNICITY",
        "DIAGNOSIS",
    ]
    df = df[cols].copy()

    interim_dir.mkdir(parents=True, exist_ok=True)
    df.to_csv(interim_dir / "patient_characteristics.csv", index=False)
    return df


def build_vital_signs_duckdb(data_dir: Path, interim_dir: Path) -> pd.DataFrame:
    """Aggregate vital signs using DuckDB to limit memory usage.

    ``CHARTEVENTS`` rows whose ``ITEMID`` appears in
    ``chartevents_numerical_features.tsv`` are grouped per
    (``SUBJECT_ID``, ``HADM_ID``, ``ICUSTAY_ID``, mapped ``NAME``) and reduced
    to their mean and sample standard deviation.  The result is pivoted to
    one row per stay with ``mean_<name>`` / ``std_<name>`` columns.

    For every column containing ``arterial_bp`` a ``combined_bp`` column is
    added that takes the arterial value when present and the matching
    ``ni_bp`` value otherwise; the ``ni_bp`` column is then dropped while the
    arterial column is kept.

    Args:
        data_dir: Directory containing ``CHARTEVENTS.csv`` or
            ``CHARTEVENTS.csv.gz``.
        interim_dir: Output directory; created if it does not exist.

    Returns:
        The wide DataFrame, which is also written to
        ``interim_dir/vital_signs.csv``.
    """
    ce_map = read_mapping('chartevents_numerical_features.tsv')
    # Locate CHARTEVENTS file (compression is detected by DuckDB itself)
    ce_path, _ = find_data_file(data_dir, 'CHARTEVENTS')
    # The path is embedded in a SQL string literal, so single quotes in it
    # must be doubled to keep the statement well-formed.
    ce_path_sql = ce_path.replace("'", "''")

    # Build SQL query to aggregate mean and std per feature per stay.
    # DuckDB automatically handles gzipped files.
    query = f"""
    WITH ce AS (
        SELECT CAST(SUBJECT_ID AS BIGINT) AS SUBJECT_ID,
               CAST(HADM_ID AS BIGINT)    AS HADM_ID,
               CAST(ICUSTAY_ID AS BIGINT) AS ICUSTAY_ID,
               CAST(ITEMID AS BIGINT)     AS ITEMID,
               CAST(VALUENUM AS DOUBLE)   AS VALUE
        FROM read_csv_auto('{ce_path_sql}')
    )
    SELECT ce.SUBJECT_ID, ce.HADM_ID, ce.ICUSTAY_ID, m.NAME,
           AVG(ce.VALUE) AS mean,
           STDDEV_SAMP(ce.VALUE) AS std
    FROM ce
    JOIN mapping_ce m ON ce.ITEMID = m.ITEMID
    WHERE ce.VALUE IS NOT NULL
    GROUP BY ce.SUBJECT_ID, ce.HADM_ID, ce.ICUSTAY_ID, m.NAME
    ORDER BY ce.SUBJECT_ID, ce.HADM_ID, ce.ICUSTAY_ID, m.NAME
    ;
    """

    con = duckdb.connect()
    try:
        # Register the mapping table as an in‑memory relation
        con.register('mapping_ce', ce_map)
        vital = con.execute(query).df()
    finally:
        # Release the connection even when the query fails.
        con.close()

    # Pivot to wide format: columns become 'mean_feature' and 'std_feature'
    wide = vital.pivot_table(index=["SUBJECT_ID", "HADM_ID", "ICUSTAY_ID"],
                             columns='NAME', values=['mean', 'std'])
    wide.columns = ['_'.join(col) for col in wide.columns]
    wide = wide.reset_index()

    # Combine arterial and cuff blood pressures.  Iterate over a snapshot of
    # the column names because columns are added and dropped inside the loop.
    for col in list(wide.columns):
        if 'arterial_bp' not in col:
            continue
        combined = col.replace('arterial_bp', 'combined_bp')
        ni = col.replace('arterial_bp', 'ni_bp')
        has_cuff = ni in wide.columns
        if combined not in wide.columns:
            # Use arterial when present, otherwise use cuff (if there is one)
            wide[combined] = wide[col].fillna(wide[ni]) if has_cuff else wide[col]
        if has_cuff:
            wide = wide.drop(columns=ni)

    interim_dir.mkdir(parents=True, exist_ok=True)
    wide.to_csv(interim_dir / "vital_signs.csv", index=False)
    return wide


def build_lab_results_duckdb(data_dir: Path, interim_dir: Path) -> pd.DataFrame:
    """Aggregate lab results using DuckDB to limit memory usage.

    ``LABEVENTS`` rows whose ``ITEMID`` appears in
    ``labevents_numerical_features.tsv`` are joined to ``ICUSTAYS`` on
    ``SUBJECT_ID`` and ``HADM_ID`` and kept only when ``CHARTTIME`` lies
    between the stay's ``INTIME`` and ``OUTTIME`` (inclusive).  Per stay and
    mapped ``NAME`` the mean and sample standard deviation are computed and
    pivoted to ``mean_<name>`` / ``std_<name>`` columns.

    Args:
        data_dir: Directory containing the ``LABEVENTS`` and ``ICUSTAYS``
            tables as ``.csv`` or ``.csv.gz`` files.
        interim_dir: Output directory; created if it does not exist.

    Returns:
        The wide DataFrame, which is also written to
        ``interim_dir/lab_results.csv``.
    """
    le_map = read_mapping('labevents_numerical_features.tsv')
    # Locate LABEVENTS and ICUSTAYS files
    le_path, _ = find_data_file(data_dir, 'LABEVENTS')
    icu_path, _ = find_data_file(data_dir, 'ICUSTAYS')
    # Both paths end up inside SQL string literals; double any single quote
    # so an unusual directory name cannot break the statement.
    le_path_sql = le_path.replace("'", "''")
    icu_path_sql = icu_path.replace("'", "''")

    # Construct query: join LABEVENTS with ICUSTAYS on subject & adm, restrict times
    query = f"""
    WITH le AS (
        SELECT CAST(SUBJECT_ID AS BIGINT) AS SUBJECT_ID,
               CAST(HADM_ID AS BIGINT)    AS HADM_ID,
               CAST(ITEMID AS BIGINT)     AS ITEMID,
               CAST(VALUENUM AS DOUBLE)   AS VALUE,
               CAST(CHARTTIME AS TIMESTAMP) AS CHARTTIME
        FROM read_csv_auto('{le_path_sql}')
    ),
    icu AS (
        SELECT CAST(SUBJECT_ID AS BIGINT) AS SUBJECT_ID,
               CAST(HADM_ID AS BIGINT)    AS HADM_ID,
               CAST(ICUSTAY_ID AS BIGINT) AS ICUSTAY_ID,
               CAST(INTIME AS TIMESTAMP)  AS INTIME,
               CAST(OUTTIME AS TIMESTAMP) AS OUTTIME
        FROM read_csv_auto('{icu_path_sql}')
    )
    SELECT le.SUBJECT_ID, le.HADM_ID, icu.ICUSTAY_ID, m.NAME,
           AVG(le.VALUE) AS mean,
           STDDEV_SAMP(le.VALUE) AS std
    FROM le
    JOIN icu ON le.SUBJECT_ID = icu.SUBJECT_ID AND le.HADM_ID = icu.HADM_ID
    JOIN mapping_le m ON le.ITEMID = m.ITEMID
    WHERE le.VALUE IS NOT NULL
      AND le.CHARTTIME >= icu.INTIME
      AND le.CHARTTIME <= icu.OUTTIME
    GROUP BY le.SUBJECT_ID, le.HADM_ID, icu.ICUSTAY_ID, m.NAME
    ORDER BY le.SUBJECT_ID, le.HADM_ID, icu.ICUSTAY_ID, m.NAME
    ;
    """

    con = duckdb.connect()
    try:
        # Expose the ITEMID -> NAME mapping to SQL as ``mapping_le``.
        con.register('mapping_le', le_map)
        labs = con.execute(query).df()
    finally:
        # Release the connection even when the query fails.
        con.close()

    # Pivot to wide format
    wide = labs.pivot_table(index=["SUBJECT_ID", "HADM_ID", "ICUSTAY_ID"],
                            columns='NAME', values=['mean', 'std'])
    wide.columns = ['_'.join(col) for col in wide.columns]
    wide = wide.reset_index()

    interim_dir.mkdir(parents=True, exist_ok=True)
    wide.to_csv(interim_dir / "lab_results.csv", index=False)
    return wide


def remove_outliers(df: pd.DataFrame, feature_cols: List[str], n_iqrs: float = 8.0) -> pd.DataFrame:
    """Remove extreme outlier rows using an IQR filter (same as upstream code).

    A row is dropped when, for at least one column in ``feature_cols``, its
    value lies below ``Q1 - n_iqrs * IQR`` or above ``Q3 + n_iqrs * IQR`` of
    that column.  Missing values never count as outliers.

    Args:
        df: Input frame; it is not modified.
        feature_cols: Columns to inspect.  Boolean columns are treated as
            0/1 values.
        n_iqrs: Width of the accepted band in multiples of the IQR.

    Returns:
        A copy of the retained rows with all original columns and dtypes.
    """
    if len(feature_cols) == 0:
        # Nothing to inspect, so nothing can be an outlier.
        return df.copy()

    features = df[feature_cols]
    # Quantile interpolation is not reliably defined for boolean data across
    # pandas/NumPy versions, so evaluate such columns (e.g. one-hot dummies)
    # as floats.
    bool_cols = list(features.select_dtypes(include="bool").columns)
    if bool_cols:
        features = features.astype({name: float for name in bool_cols})

    q1 = features.quantile(0.25)
    q3 = features.quantile(0.75)
    iqr = q3 - q1
    too_low = features < (q1 - n_iqrs * iqr)
    too_high = features > (q3 + n_iqrs * iqr)
    keep = ~(too_low | too_high).any(axis=1)
    return df.loc[keep].copy()


def concat_and_finalize(
    interim_dir: Path,
    processed_dir: Path,
    random_state: int | None = None,
) -> None:
    """Merge interim tables, clean, split and scale features.

    Steps, in order:

    1. Inner-join ``lab_results.csv``, ``patient_characteristics.csv`` and
       ``vital_signs.csv`` on ``SUBJECT_ID``, ``HADM_ID``, ``ICUSTAY_ID``.
    2. One-hot encode ``GENDER`` (first category dropped) as 0/1 floats.
    3. Set aside rows with ``ADMISSION_TYPE == 'NEWBORN'``.
    4. On the remaining rows: drop 8×IQR outliers, drop rows whose mean
       combined blood pressure is not above 10 or whose pre-ICU hospital
       time is not positive, fill missing ``std`` features with zero.
    5. Split 90/10 into train and test, fit a median imputer and a min-max
       scaler on the training features and apply them to train, test and
       newborn features.
    6. Write raw and processed matrices and labels to ``processed_dir``.

    Args:
        interim_dir: Directory containing the three interim CSV files.
        processed_dir: Output directory; created if it does not exist.
        random_state: Optional seed for the train/test split.  The default
            ``None`` keeps the split unseeded, as in the upstream code.
    """
    key_cols = ["SUBJECT_ID", "HADM_ID", "ICUSTAY_ID"]
    label_col = "EXPIRED_THIS_ICUSTAY"

    lab_results = pd.read_csv(interim_dir / "lab_results.csv")
    vital_signs = pd.read_csv(interim_dir / "vital_signs.csv")
    patient_char = pd.read_csv(interim_dir / "patient_characteristics.csv")

    # Merge
    df = lab_results.merge(patient_char, on=key_cols, how="inner")
    df = df.merge(vital_signs, on=key_cols, how="inner")

    # One‑hot encode gender.  An explicit float dtype keeps the dummy column
    # numeric; recent pandas versions would otherwise emit booleans, which
    # do not mix with the quantile, imputation and scaling steps below.
    if "GENDER" in df.columns:
        df = pd.get_dummies(df, columns=["GENDER"], drop_first=True, dummy_na=False, dtype=float)

    # Separate newborns
    is_newborn = df["ADMISSION_TYPE"] == "NEWBORN"
    df_newborns = df[is_newborn].copy()
    df = df[~is_newborn].copy()

    # Exclude non‑feature columns
    exclude_cols = {
        "SUBJECT_ID", "HADM_ID", "ICUSTAY_ID", "EXPIRED_THIS_ICUSTAY",
        "ADMISSION_TYPE", "ADMISSION_LOCATION", "DISCHARGE_LOCATION", "ETHNICITY", "DIAGNOSIS"
    }
    feature_cols = [c for c in df.columns if c not in exclude_cols]

    # Remove outliers
    df = remove_outliers(df, feature_cols, n_iqrs=8.0)

    # Plausibility filters (rows with a missing value in the tested column
    # fail the comparison and are dropped as well)
    for col in ["mean_combined_bp_dia", "mean_combined_bp_sys"]:
        if col in df.columns:
            df = df[df[col] > 10]
    if "time_at_hosp_pre_ic_admission" in df.columns:
        df = df[df["time_at_hosp_pre_ic_admission"] > 0]
    # Work on an independent frame so the column assignments below do not
    # target a filtered view.
    df = df.copy()

    # Fill std NaNs with zero and coerce to numeric
    # NOTE(review): newborn rows do not get this zero fill; their missing
    # std values are median-imputed further down -- confirm this is intended.
    for col in feature_cols + [label_col]:
        if col not in df.columns:
            continue
        if "std" in col:
            df[col] = df[col].fillna(0)
        df[col] = pd.to_numeric(df[col], errors="raise")

    # Train/test split (10% test).  Upstream code uses no fixed seed.
    X_train_raw, X_test_raw = train_test_split(df, test_size=0.1, random_state=random_state)
    y_train = X_train_raw[label_col].astype(int)
    y_test = X_test_raw[label_col].astype(int)

    # Impute and scale; both transformers are fitted on training data only.
    imputer = SimpleImputer(strategy='median')
    scaler = MinMaxScaler()

    def _as_frame(values, like: pd.DataFrame) -> pd.DataFrame:
        """Wrap a transformed array in a frame with ``like``'s row index."""
        return pd.DataFrame(values, columns=feature_cols, index=like.index)

    train_features = X_train_raw[feature_cols]
    test_features = X_test_raw[feature_cols]
    X_train_processed = _as_frame(
        scaler.fit_transform(imputer.fit_transform(train_features)), train_features
    )
    X_test_processed = _as_frame(
        scaler.transform(imputer.transform(test_features)), test_features
    )

    # Transform newborns
    if not df_newborns.empty:
        newborn_features = df_newborns[feature_cols]
        X_newborns = _as_frame(
            scaler.transform(imputer.transform(newborn_features)), newborn_features
        )
        y_newborns = df_newborns[label_col].astype(int)
    else:
        X_newborns = pd.DataFrame(columns=feature_cols)
        # Named so the (empty) CSV carries the same header as the other labels.
        y_newborns = pd.Series(dtype=int, name=label_col)

    # Ensure processed directory exists
    processed_dir.mkdir(parents=True, exist_ok=True)

    # Write outputs
    X_train_raw.to_csv(processed_dir / "X_train_raw.csv", index=False)
    X_test_raw.to_csv(processed_dir / "X_test_raw.csv", index=False)
    X_train_processed.to_csv(processed_dir / "X_train_processed.csv", index=False)
    X_test_processed.to_csv(processed_dir / "X_test_processed.csv", index=False)
    y_train.to_csv(processed_dir / "y_train.csv", index=False, header=True)
    y_test.to_csv(processed_dir / "y_test.csv", index=False, header=True)
    y_newborns.to_csv(processed_dir / "y_newborns.csv", index=False, header=True)
    X_newborns.to_csv(processed_dir / "X_newborns.csv", index=False)


def main(argv: List[str] | None = None) -> None:
    """Parse command-line options and run the four pipeline stages in order.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``.  The
            default ``None`` reads the process arguments as before.
    """
    parser = argparse.ArgumentParser(description="Preprocess MIMIC‑III data to match the Pacmed/aisg_2019 dataset")
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Path to the directory containing MIMIC‑III v1.4 CSV files (compressed or uncompressed)",
    )
    # ``%(default)s`` lets argparse print the real default, so the help text
    # cannot drift away from the value configured here.
    parser.add_argument(
        "--interim_dir",
        type=str,
        default="data2/interim",
        help="Directory to write interim CSV files (default: %(default)s)",
    )
    parser.add_argument(
        "--processed_dir",
        type=str,
        default="data2/processed",
        help="Directory to write processed CSV files (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    data_dir = Path(args.data_dir)
    interim_dir = Path(args.interim_dir)
    processed_dir = Path(args.processed_dir)

    # Run pipeline: each stage writes CSVs that the final stage reads back.
    build_patient_characteristics(data_dir, interim_dir)
    build_vital_signs_duckdb(data_dir, interim_dir)
    build_lab_results_duckdb(data_dir, interim_dir)
    concat_and_finalize(interim_dir, processed_dir)


if __name__ == '__main__':
    main()
