"""
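Build the processed CSV (data/nhanes_cost_table.csv) from raw NHANES XPT files.

For each participant, computes the signed distance Δ between mean systolic
blood pressure (SBP) and the hypertension threshold (Δ = SBP − threshold),
its magnitude |Δ|, and a binary hypertension label (1 when Δ ≥ 0).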

Prerequisites:
  Download these files from CDC NHANES 2013-2014 to data/nhanes/raw/:
  - DEMO_H.XPT (demographics)
  - BPX_H.XPT (blood pressure)
  - BMX_H.XPT (body measures)

Usage: python -m src.data.preprocess_nhanes
"""
import pandas as pd
import numpy as np
from pathlib import Path

# Input locations: the three NHANES 2013-2014 XPT files all sit in RAW_DIR.
RAW_DIR = Path("data", "nhanes", "raw")
DEMO_XPT = RAW_DIR.joinpath("DEMO_H.XPT")
BPX_XPT = RAW_DIR.joinpath("BPX_H.XPT")
BMX_XPT = RAW_DIR.joinpath("BMX_H.XPT")

# Destination of the processed table (kept as a plain string path).
OUT_CSV = "data/nhanes_cost_table.csv"

# Systolic cut-off in mmHg used to label a participant hypertensive:
#   130 -> 2017 ACC/AHA guideline (the value used here)
#   140 -> older JNC7 guideline
SBP_THRESHOLD = 130


def _check_inputs(xpt_files):
    """Raise FileNotFoundError naming every required XPT file that is missing.

    All missing files are reported at once (not just the first) so the user
    can download everything in one go.
    """
    missing = [path for path in xpt_files if not path.exists()]
    if missing:
        names = ", ".join(path.name for path in missing)
        raise FileNotFoundError(
            f"Missing {names}\n"
            "Download from: https://wwwn.cdc.gov/nchs/nhanes/continuousnhanes/overview.aspx?BeginYear=2013"
        )


def _load_and_merge(demo_xpt, bpx_xpt, bmx_xpt):
    """Read the three XPT files and inner-join them on SEQN (participant ID).

    The inner join keeps only participants present in all three files; it
    does NOT guarantee that their measurements are non-missing.
    """
    print("  Reading demographics...")
    demo = pd.read_sas(demo_xpt)

    print("  Reading blood pressure...")
    bpx = pd.read_sas(bpx_xpt)

    print("  Reading body measures...")
    bmx = pd.read_sas(bmx_xpt)

    print(f"  Loaded {len(demo):,} participants from demographics")

    df = demo.merge(bpx, on="SEQN", how="inner").merge(bmx, on="SEQN", how="inner")
    print(f"  After merging: {len(df):,} participants present in all three files")
    return df


def _add_bp_features(df: pd.DataFrame, sbp_threshold: float) -> pd.DataFrame:
    """Return rows with >= 1 valid SBP reading, plus the derived BP columns.

    Adds:
      sbp_avg       mean of the non-missing readings among BPXSY2-BPXSY4
      delta_signed  sbp_avg - sbp_threshold
      abs_delta     |delta_signed| (the cost weight in the output table)
      y_star        1 if sbp_avg >= sbp_threshold else 0
    """
    # Reading 1 (BPXSY1) is deliberately excluded; only readings 2-4 are
    # averaged. NOTE(review): this was previously attributed to the NHANES
    # protocol -- confirm against the BPX_H documentation. Participants whose
    # only valid reading is BPXSY1 are dropped by the filter below.
    sbp_cols = ["BPXSY2", "BPXSY3", "BPXSY4"]
    df = df[df[sbp_cols].notna().any(axis=1)].copy()
    df["sbp_avg"] = df[sbp_cols].mean(axis=1, skipna=True)
    df["delta_signed"] = df["sbp_avg"] - sbp_threshold
    df["abs_delta"] = df["delta_signed"].abs()
    # '>=' so that a mean exactly at the threshold counts as hypertensive.
    df["y_star"] = (df["delta_signed"] >= 0).astype(int)
    return df


def _print_summary(df_out, out_csv, sbp_threshold):
    """Print the row count and headline statistics of the written table."""
    print(f"\n✓ Wrote {out_csv}")
    print(f"  Rows: {len(df_out):,}")
    print(f"  Mean SBP: {df_out['sbp_avg'].mean():.1f} mmHg")
    print(f"  Mean |Δ|: {df_out['abs_delta'].mean():.1f} mmHg")
    print(f"  Max |Δ|: {df_out['abs_delta'].max():.1f} mmHg")
    print(f"  Positive class (hypertensive, SBP≥{sbp_threshold}): {df_out['y_star'].mean():.1%}")
    print(f"  Mean age: {df_out['RIDAGEYR'].mean():.1f} years")
    print(f"  Mean BMI: {df_out['BMXBMI'].mean():.1f}")


def main(raw_dir=None, out_csv=None, sbp_threshold=None, min_age=None) -> None:
    """Build the processed NHANES cost table and write it to CSV.

    Called with no arguments, this uses the module-level configuration
    (RAW_DIR / DEMO_XPT / BPX_XPT / BMX_XPT, OUT_CSV, SBP_THRESHOLD). Defaults
    are resolved at call time, so reassigning those constants still applies.

    Args:
        raw_dir: directory containing DEMO_H.XPT, BPX_H.XPT and BMX_H.XPT.
            Defaults to RAW_DIR.
        out_csv: output CSV path. Defaults to OUT_CSV. Missing parent
            directories are created.
        sbp_threshold: SBP cut-off in mmHg; rows whose mean SBP is >= this
            value get y_star = 1. Defaults to SBP_THRESHOLD.
        min_age: if given, keep only participants with RIDAGEYR >= min_age.
            Default None keeps every age. NOTE(review): the threshold is an
            adult guideline value while the BP file may include children --
            consider min_age=18 (confirm the intended study population).

    Raises:
        FileNotFoundError: if any of the input XPT files is missing.
    """
    if raw_dir is None:
        raw_dir = RAW_DIR
        xpt_files = (DEMO_XPT, BPX_XPT, BMX_XPT)
    else:
        raw_dir = Path(raw_dir)
        xpt_files = tuple(raw_dir / p.name for p in (DEMO_XPT, BPX_XPT, BMX_XPT))
    if out_csv is None:
        out_csv = OUT_CSV
    if sbp_threshold is None:
        sbp_threshold = SBP_THRESHOLD

    print(f"Loading NHANES 2013-2014 XPT files from {raw_dir}...")
    _check_inputs(xpt_files)

    df = _load_and_merge(*xpt_files)

    if min_age is not None:
        df = df[df["RIDAGEYR"] >= min_age]
        print(f"  After age filter (RIDAGEYR >= {min_age}): {len(df):,} participants")

    df = _add_bp_features(df, sbp_threshold)
    print(f"  After filtering for valid BP: {len(df):,} participants")

    # RIDAGEYR = age in years; RIAGENDR = gender (1=male, 2=female);
    # RIDRETH3 = race/ethnicity; BMXBMI = BMI
    output_cols = [
        "SEQN",           # Participant ID
        "RIDAGEYR",       # Age
        "RIAGENDR",       # Gender
        "RIDRETH3",       # Race/ethnicity
        "BMXBMI",         # BMI
        "sbp_avg",        # Average SBP
        "delta_signed",   # Signed distance from threshold
        "abs_delta",      # Absolute distance (cost weight)
        "y_star",         # Binary label
    ]

    # This is the real completeness filter: drop rows missing any output feature.
    df_out = df[output_cols].dropna()
    print(f"  After dropping missing values: {len(df_out):,} complete cases")

    Path(out_csv).parent.mkdir(parents=True, exist_ok=True)
    df_out.to_csv(out_csv, index=False)

    _print_summary(df_out, out_csv, sbp_threshold)


# Run the preprocessing only when executed as a script or module
# (e.g. `python -m src.data.preprocess_nhanes`), not when imported.
if __name__ == "__main__":
    main()
