#!/usr/bin/env python3
"""
NHANES Data Preprocessing Pipeline
===================================
Processes NHANES survey data (2007-2018) for biological age and mortality risk prediction.

This script:
1. Loads biomarker, demographic, body measurement, blood pressure, and mortality data
2. Merges all datasets on SEQN (unique participant ID)
3. Engineers derived features (NLR, WHR, MAP, PP, etc.)
4. Applies data cleaning (winsorization, log transforms, imputation, standardization)
5. Creates stratified train/val/test splits
6. Saves processed datasets in parquet format

Author: Algorithm Implementation Agent
Created: 2026-01-12
"""

import os
import json
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Module-wide logging configuration: INFO threshold with timestamped records.
# NOTE(review): pipeline progress is reported through TrainingLogger instances;
# this module-level `logger` appears unused in this file.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


class TrainingLogger:
    """Collect timestamped log lines, echo them to stdout, and persist them on request.

    Args:
        log_path: File that ``save`` (over)writes with every collected entry.
    """

    def __init__(self, log_path: str):
        self.log_path = log_path
        self.logs = []  # formatted entries, in the order they were logged

    def log(self, message: str, level: str = "INFO"):
        """Format *message* as ``[timestamp] [level] message``, keep it, and print it."""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_entry = f"[{timestamp}] [{level}] {message}"
        print_and_store = (self.logs.append, print)
        for sink in print_and_store:
            sink(log_entry)

    def save(self):
        """Overwrite ``log_path`` with all entries joined by newlines (no trailing newline)."""
        payload = '\n'.join(self.logs)
        with open(self.log_path, 'w') as handle:
            handle.write(payload)


class NHANESDataLoader:
    """Load and merge NHANES datasets from .xpt (SAS transport) and fixed-width .dat files.

    The component loaders return one row per SEQN plus a ``cycle`` column holding
    the cycle letter, or an empty DataFrame when no file could be loaded.
    """

    # Cycle letter -> survey years
    CYCLE_YEARS = {
        'E': '2007-2008',
        'F': '2009-2010',
        'G': '2011-2012',
        'H': '2013-2014',
        'I': '2015-2016',
        'J': '2017-2018',
    }

    # Cycle letter -> public-use linked mortality file
    MORT_FILES = {
        'E': 'NHANES_2007_2008_MORT_2019_PUBLIC.dat',
        'F': 'NHANES_2009_2010_MORT_2019_PUBLIC.dat',
        'G': 'NHANES_2011_2012_MORT_2019_PUBLIC.dat',
        'H': 'NHANES_2013_2014_MORT_2019_PUBLIC.dat',
        'I': 'NHANES_2015_2016_MORT_2019_PUBLIC.dat',
        'J': 'NHANES_2017_2018_MORT_2019_PUBLIC.dat',
    }

    # Cycle letter -> (file prefix, source column, multiplier that yields mg/L).
    # Only these cycles are loaded for CRP. The standard-CRP files (E, F) report
    # LBXCRP in mg/dL, while the high-sensitivity files (I, J) report LBXHSCRP
    # in mg/L. The older values must therefore be multiplied by 10 before the
    # cycles are pooled into one 'CRP' column; otherwise the pooled feature
    # carries a 10x scale shift between cycles.
    CRP_SOURCES = {
        'E': ('CRP', 'LBXCRP', 10.0),
        'F': ('CRP', 'LBXCRP', 10.0),
        'I': ('HSCRP', 'LBXHSCRP', 1.0),
        'J': ('HSCRP', 'LBXHSCRP', 1.0),
    }

    def __init__(self, data_dir: str, logger: 'TrainingLogger'):
        self.data_dir = Path(data_dir)
        self.logger = logger

    def load_xpt_data(self, file_pattern: str, cycles: List[str]) -> pd.DataFrame:
        """
        Load and concatenate .xpt files across cycles.

        Args:
            file_pattern: Pattern like 'BIOPRO_{cycle}.xpt'
            cycles: List of cycle codes, e.g. ['E', 'F', 'G', 'H', 'I', 'J']

        Returns:
            Concatenated DataFrame with a 'cycle' column. Missing files are
            logged as warnings and skipped. The result is empty if none were found.
        """
        dfs = []
        for cycle in cycles:
            filepath = self.data_dir / file_pattern.format(cycle=cycle)
            if filepath.exists():
                df = pd.read_sas(filepath)
                df['cycle'] = cycle
                dfs.append(df)
                self.logger.log(f"  Loaded {filepath.name}: {len(df)} rows, {len(df.columns)} columns")
            else:
                self.logger.log(f"  Warning: {filepath.name} not found", "WARNING")

        if dfs:
            return pd.concat(dfs, ignore_index=True)
        return pd.DataFrame()

    def load_mortality_data(self, cycles: List[str]) -> pd.DataFrame:
        """
        Parse the fixed-width linked mortality .dat files.

        Args:
            cycles: Cycle letters. Each must be a key of ``MORT_FILES``;
                an unknown letter raises KeyError.

        Returns:
            DataFrame with SEQN, ELIGSTAT, MORTSTAT, UCOD_LEADING, PERMTH_INT,
            PERMTH_EXM and cycle. Empty if no file was found.
        """
        self.logger.log("Loading mortality data...")

        # CDC column layout as 0-indexed half-open spans for pandas
        # (1-based positions: SEQN 1-6, ELIGSTAT 15, MORTSTAT 16,
        # UCOD_LEADING 17-19, PERMTH_INT 43-45, PERMTH_EXM 46-48).
        colspecs = [
            (0, 6),    # SEQN
            (14, 15),  # ELIGSTAT
            (15, 16),  # MORTSTAT
            (16, 19),  # UCOD_LEADING
            (42, 45),  # PERMTH_INT
            (45, 48),  # PERMTH_EXM
        ]
        names = ['SEQN', 'ELIGSTAT', 'MORTSTAT', 'UCOD_LEADING', 'PERMTH_INT', 'PERMTH_EXM']

        dfs = []
        for cycle in cycles:
            filepath = self.data_dir / self.MORT_FILES[cycle]
            if filepath.exists():
                # '.' and blanks mark missing values in these files
                df = pd.read_fwf(filepath, colspecs=colspecs, names=names, na_values=['.', ''])
                df['cycle'] = cycle
                dfs.append(df)
                self.logger.log(f"  Loaded {filepath.name}: {len(df)} rows")
            else:
                self.logger.log(f"  Warning: {filepath.name} not found", "WARNING")

        if dfs:
            return pd.concat(dfs, ignore_index=True)
        return pd.DataFrame()

    def load_demographics(self, cycles: List[str]) -> pd.DataFrame:
        """Load SEQN, age (RIDAGEYR -> 'age') and sex (RIAGENDR -> 'sex') from the DEMO files."""
        self.logger.log("Loading demographic data...")
        df = self.load_xpt_data('DEMO_{cycle}.xpt', cycles)

        if df.empty:
            self.logger.log("No demographic data found!", "ERROR")
            return df

        # Keep only the key columns that are present
        cols = ['SEQN', 'RIDAGEYR', 'RIAGENDR', 'cycle']
        available = [c for c in cols if c in df.columns]
        df = df[available].copy()

        df = df.rename(columns={
            'RIDAGEYR': 'age',
            'RIAGENDR': 'sex'
        })

        # Collapse possible duplicate SEQNs (first non-null value per column)
        df = df.groupby('SEQN').first().reset_index()
        self.logger.log(f"Demographics: {len(df)} unique participants")
        return df

    def _load_one_row_per_seqn(self, file_pattern: str, cycles: List[str],
                               description: str, summary_label: str) -> pd.DataFrame:
        """Load an .xpt family via ``load_xpt_data`` and collapse it to one row per SEQN.

        ``groupby('SEQN').first()`` takes the first non-null value per column,
        sorts by SEQN, and drops rows whose SEQN is missing.
        """
        self.logger.log(f"Loading {description} data...")
        df = self.load_xpt_data(file_pattern, cycles)

        if df.empty:
            return df

        df = df.groupby('SEQN').first().reset_index()
        self.logger.log(f"{summary_label}: {len(df)} unique participants")
        return df

    def load_biochemistry(self, cycles: List[str]) -> pd.DataFrame:
        """Load biochemistry profile data (BIOPRO)."""
        return self._load_one_row_per_seqn('BIOPRO_{cycle}.xpt', cycles, 'biochemistry', 'Biochemistry')

    def load_cbc(self, cycles: List[str]) -> pd.DataFrame:
        """Load complete blood count data (CBC)."""
        return self._load_one_row_per_seqn('CBC_{cycle}.xpt', cycles, 'CBC', 'CBC')

    def load_body_measures(self, cycles: List[str]) -> pd.DataFrame:
        """Load body measurement data (BMX)."""
        return self._load_one_row_per_seqn('BMX_{cycle}.xpt', cycles, 'body measurement', 'Body measures')

    def load_blood_pressure(self, cycles: List[str]) -> pd.DataFrame:
        """Load blood pressure data (BPX)."""
        return self._load_one_row_per_seqn('BPX_{cycle}.xpt', cycles, 'blood pressure', 'Blood pressure')

    @staticmethod
    def _to_crp_mg_per_l(df: pd.DataFrame, source_col: str, factor: float) -> pd.DataFrame:
        """Return a copy of *df* with *source_col* renamed to 'CRP' and multiplied by *factor*.

        If *source_col* is absent, *df* is returned unchanged.
        """
        if source_col not in df.columns:
            return df
        df = df.rename(columns={source_col: 'CRP'})
        df['CRP'] = df['CRP'] * factor
        return df

    def load_crp(self, cycles: List[str]) -> pd.DataFrame:
        """Load CRP for the requested cycles into a single 'CRP' column expressed in mg/L.

        Standard CRP (cycles E, F) is rescaled from mg/dL to mg/L so it can be
        pooled with high-sensitivity CRP (cycles I, J). Requested cycles without
        a CRP source in ``CRP_SOURCES`` are ignored.

        Returns:
            One row per SEQN, or an empty DataFrame if no CRP file was loaded.
        """
        self.logger.log("Loading CRP data...")

        dfs = []
        for cycle, (prefix, source_col, to_mg_per_l) in self.CRP_SOURCES.items():
            if cycle not in cycles:
                continue
            filename = f'{prefix}_{cycle}.xpt'
            filepath = self.data_dir / filename
            if not filepath.exists():
                self.logger.log(f"  Warning: {filename} not found", "WARNING")
                continue
            df = pd.read_sas(filepath)
            df['cycle'] = cycle
            df = self._to_crp_mg_per_l(df, source_col, to_mg_per_l)
            if to_mg_per_l != 1.0 and 'CRP' in df.columns:
                self.logger.log(f"  Converted {source_col} from mg/dL to mg/L (x{to_mg_per_l:g})")
            dfs.append(df)
            self.logger.log(f"  Loaded {filename}: {len(df)} rows")

        if dfs:
            result = pd.concat(dfs, ignore_index=True)
            result = result.groupby('SEQN').first().reset_index()
            self.logger.log(f"CRP: {len(result)} unique participants")
            return result

        return pd.DataFrame()


class NHANESFeatureEngineer:
    """Compute derived clinical features from merged raw NHANES columns.

    Each ``compute_*`` method adds its feature(s) only when the required source
    columns are present and writes into the frame it receives.
    ``compute_all_features`` works on a copy, so the caller's frame is untouched.
    """

    def __init__(self, logger: 'TrainingLogger'):
        self.logger = logger

    def compute_all_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of *df* with NLR, WHR, SBP/DBP means, MAP, PP and de_ritis_ratio added where possible."""
        df = df.copy()

        df = self.compute_hematology_features(df)
        df = self.compute_body_composition(df)
        df = self.compute_blood_pressure_features(df)
        df = self.compute_metabolic_features(df)

        return df

    def compute_hematology_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add NLR = LBDNENO / LBDLYMNO (neutrophil / lymphocyte count); a zero denominator gives NaN."""
        if 'LBDNENO' in df.columns and 'LBDLYMNO' in df.columns:
            lymphocyte = df['LBDLYMNO'].replace(0, np.nan)
            df['NLR'] = df['LBDNENO'] / lymphocyte
            self.logger.log("Computed NLR (Neutrophil-Lymphocyte Ratio)")

        return df

    def compute_body_composition(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add WHR = BMXWAIST / BMXHIP (waist / hip); a zero hip value gives NaN."""
        if 'BMXWAIST' in df.columns and 'BMXHIP' in df.columns:
            hip = df['BMXHIP'].replace(0, np.nan)
            df['WHR'] = df['BMXWAIST'] / hip
            self.logger.log("Computed WHR (Waist-Hip Ratio)")

        return df

    def _mean_valid_readings(self, df: pd.DataFrame, cols: list, label: str) -> pd.Series:
        """Row-wise mean of *cols*, treating readings <= 0 as missing (NaNs are skipped)."""
        readings = df[cols]
        n_invalid = int((readings <= 0).sum().sum())
        if n_invalid:
            self.logger.log(f"  Treated {n_invalid} non-positive {label} readings as missing", "WARNING")
        return readings.where(readings > 0).mean(axis=1, skipna=True)

    def compute_blood_pressure_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Average repeated SBP/DBP readings and derive MAP and pulse pressure.

        Readings are the BPXSY*/BPXDI* columns whose name ends in a digit.
        Values <= 0 mmHg are treated as missing before averaging. The BPX
        files contain diastolic readings recorded as 0, which are not
        physiological and would otherwise pull DBP_mean (and hence MAP and
        PP) toward zero.
        """
        sbp_cols = [c for c in df.columns if c.startswith('BPXSY') and c[-1].isdigit()]
        dbp_cols = [c for c in df.columns if c.startswith('BPXDI') and c[-1].isdigit()]

        if sbp_cols:
            df['SBP_mean'] = self._mean_valid_readings(df, sbp_cols, 'systolic')
            self.logger.log(f"Computed SBP_mean from {len(sbp_cols)} readings")

        if dbp_cols:
            df['DBP_mean'] = self._mean_valid_readings(df, dbp_cols, 'diastolic')
            self.logger.log(f"Computed DBP_mean from {len(dbp_cols)} readings")

        if 'SBP_mean' in df.columns and 'DBP_mean' in df.columns:
            # Mean Arterial Pressure: MAP = DBP + (SBP - DBP) / 3
            df['MAP'] = df['DBP_mean'] + (df['SBP_mean'] - df['DBP_mean']) / 3
            self.logger.log("Computed MAP (Mean Arterial Pressure)")

            # Pulse Pressure: PP = SBP - DBP
            df['PP'] = df['SBP_mean'] - df['DBP_mean']
            self.logger.log("Computed PP (Pulse Pressure)")

        return df

    def compute_metabolic_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add the De Ritis ratio AST/ALT (LBXSASSI / LBXSATSI); a zero ALT gives NaN."""
        if 'LBXSASSI' in df.columns and 'LBXSATSI' in df.columns:
            alt = df['LBXSATSI'].replace(0, np.nan)
            df['de_ritis_ratio'] = df['LBXSASSI'] / alt
            self.logger.log("Computed de_ritis_ratio (AST/ALT)")

        return df


class NHANESDataCleaner:
    """Handle data cleaning: winsorization, log transforms, median imputation, standardization.

    Imputation medians (``impute_values``) and the ``StandardScaler`` are
    stateful. Fit them with ``fit=True`` (intended for the training split) and
    reuse them with ``fit=False`` on other splits.
    """

    def __init__(self, logger: 'TrainingLogger'):
        self.logger = logger
        self.scaler = None         # StandardScaler, set by standardize_features(fit=True)
        self.impute_values = {}    # column -> median learned with fit=True
        self.fitted_columns = []   # columns the scaler was fit on, in order

    def winsorize(self, series: pd.Series, lower_pct: float = 1, upper_pct: float = 99) -> pd.Series:
        """Clip *series* to its own [lower_pct, upper_pct] percentiles (NaNs ignored for the bounds)."""
        lower = np.nanpercentile(series, lower_pct)
        upper = np.nanpercentile(series, upper_pct)
        return series.clip(lower, upper)

    def log_transform(self, series: pd.Series) -> pd.Series:
        """Return log1p(series) after clipping negatives to 0; NaN stays NaN."""
        return np.log1p(series.clip(lower=0))

    def winsorize_features(self, df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
        """Return a copy of *df* with each present column in *columns* winsorized against its own percentiles."""
        df = df.copy()
        for col in columns:
            if col in df.columns:
                df[col] = self.winsorize(df[col])
        self.logger.log(f"Winsorized {len(columns)} features")
        return df

    def log_transform_features(self, df: pd.DataFrame, columns: List[str]) -> pd.DataFrame:
        """Return a copy of *df* with a ``<col>_log`` column added for each present column in *columns*."""
        df = df.copy()
        transformed = 0
        for col in columns:
            if col in df.columns:
                df[f'{col}_log'] = self.log_transform(df[col])
                transformed += 1
        self.logger.log(f"Log-transformed {transformed} features")
        return df

    def handle_missing_values(self, df: pd.DataFrame, feature_cols: List[str],
                              fit: bool = True, max_missing_pct: float = 50.0) -> pd.DataFrame:
        """
        Impute missing values with the median (medians learned on the fit data only).

        Args:
            df: DataFrame to process (not modified; a copy is returned).
            feature_cols: Feature columns to impute; absent columns are ignored.
            fit: If True, learn and store each column's median from *df*.
                Columns whose missing rate exceeds *max_missing_pct* are skipped
                and left unimputed. If False, apply every stored median. The
                missing-rate gate is NOT re-evaluated on *df*, so val/test splits
                are imputed exactly like the training split.
            max_missing_pct: Fit-time threshold (percent) above which a column is skipped.
        """
        df = df.copy()
        n_rows = len(df)

        for col in feature_cols:
            if col not in df.columns:
                continue

            if fit:
                # An empty frame carries no information: treat it as fully missing.
                missing_pct = df[col].isnull().sum() / n_rows * 100 if n_rows else 100.0
                if missing_pct > max_missing_pct:
                    self.logger.log(f"  Skipping {col}: {missing_pct:.1f}% missing (>{max_missing_pct}%)", "WARNING")
                    continue
                self.impute_values[col] = df[col].median()

            if col in self.impute_values:
                df[col] = df[col].fillna(self.impute_values[col])

        self.logger.log(f"Imputed missing values for {len(self.impute_values)} features")
        return df

    def standardize_features(self, df: pd.DataFrame, feature_cols: List[str],
                             fit: bool = True) -> pd.DataFrame:
        """
        Apply Z-score standardization.

        Args:
            df: DataFrame to process (not modified; a copy is returned).
            feature_cols: Candidate columns. With fit=True, only present numeric
                ones are used. With fit=False, the columns seen at fit time are used.
            fit: If True, fit a new StandardScaler on *df*. If False, reuse the stored one.

        Raises:
            ValueError: If fit=False and no scaler has been fitted, or if *df*
                lacks columns the scaler was fitted on.
        """
        df = df.copy()

        if fit:
            available_cols = [c for c in feature_cols if c in df.columns and
                              pd.api.types.is_numeric_dtype(df[c])]

            if not available_cols:
                self.logger.log("No numeric columns to standardize", "WARNING")
                return df

            self.scaler = StandardScaler()
            df[available_cols] = self.scaler.fit_transform(df[available_cols])
            self.fitted_columns = available_cols
            self.logger.log(f"Fitted StandardScaler on {len(available_cols)} features")
            return df

        if self.scaler is None:
            raise ValueError("Scaler not fitted. Call with fit=True first.")

        # The scaler needs exactly the columns it was fitted on.
        absent = [c for c in self.fitted_columns if c not in df.columns]
        if absent:
            raise ValueError(f"Cannot standardize: columns seen at fit time are missing: {absent}")

        df[self.fitted_columns] = self.scaler.transform(df[self.fitted_columns])
        self.logger.log(f"Applied StandardScaler to {len(self.fitted_columns)} features")
        return df


def construct_survival_outcomes(df: pd.DataFrame, df_mort: pd.DataFrame,
                                logger: TrainingLogger) -> pd.DataFrame:
    """
    Attach linked-mortality follow-up to *df* and derive survival targets.

    Inner-joins on SEQN, keeps ELIGSTAT == 1, converts PERMTH_EXM (months) to
    ``time_years``, casts MORTSTAT to the integer ``event`` flag, and drops rows
    whose follow-up time is not strictly positive (NaN times are dropped too).
    The per-cause breakdown of deaths (UCOD_LEADING) is logged.

    Args:
        df: Participant-level features keyed by SEQN.
        df_mort: Mortality data with SEQN, ELIGSTAT, MORTSTAT, PERMTH_EXM, UCOD_LEADING.
        logger: Progress logger.

    Returns:
        Filtered copy with 'time_years', 'event' and 'UCOD_LEADING' columns.
    """
    logger.log("Constructing survival outcomes...")

    # UCOD_LEADING is carried along for cause-of-death reporting only.
    mort_cols = ['SEQN', 'ELIGSTAT', 'MORTSTAT', 'PERMTH_EXM', 'UCOD_LEADING']
    merged = df.merge(df_mort[mort_cols], on='SEQN', how='inner')
    logger.log(f"After mortality merge: {len(merged)} rows")

    eligible = merged.loc[merged['ELIGSTAT'] == 1].copy()
    logger.log(f"After ELIGSTAT=1 filter: {len(eligible)} rows")

    eligible['time_years'] = eligible['PERMTH_EXM'] / 12.0
    eligible['event'] = eligible['MORTSTAT'].astype(int)  # 1 = deceased, 0 = censored

    # Only strictly positive follow-up times are kept.
    survival = eligible.loc[eligible['time_years'] > 0].copy()

    n_deaths = survival['event'].sum()
    logger.log(f"Survival data: {n_deaths} deaths, {len(survival) - n_deaths} censored")

    if 'UCOD_LEADING' not in survival.columns:
        return survival

    cause_names = {
        1: 'Heart disease',
        2: 'Cancer',
        3: 'Chronic lower respiratory disease',
        4: 'Accidents',
        5: 'Cerebrovascular disease',
        6: "Alzheimer's disease",
        7: 'Diabetes',
        8: 'Influenza/Pneumonia',
        9: 'Nephritis/Kidney disease',
        10: 'Other causes'
    }
    deceased = survival[survival['event'] == 1]
    n_deceased = len(deceased)
    logger.log("Cause of death distribution:")
    for cause_code, cause_name in cause_names.items():
        n_cause = (deceased['UCOD_LEADING'] == cause_code).sum()
        if n_cause > 0:
            logger.log(f"  {cause_code} - {cause_name}: {n_cause} ({n_cause/n_deceased*100:.1f}%)")

    return survival


def create_stratified_splits(df: pd.DataFrame,
                             train_ratio: float = 0.70,
                             val_ratio: float = 0.15,
                             test_ratio: float = 0.15,
                             random_state: int = 42,
                             logger: Optional[TrainingLogger] = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Create stratified train/val/test splits.

    Rows are stratified jointly on ``event`` and age quintile (``pd.qcut`` on
    ``age``; duplicate bin edges are merged). Two ``train_test_split`` calls
    with the same ``random_state`` split train vs. (val + test), then val vs. test.

    Args:
        df: Must contain ``age`` and ``event``; it is not modified.
        train_ratio, val_ratio, test_ratio: Split fractions. Each must lie in
            (0, 1) and together they must sum to 1. The train split is the
            complement of val + test.
        random_state: Seed passed to both ``train_test_split`` calls.
        logger: Optional progress logger.

    Returns:
        (train_df, val_df, test_df) without the temporary ``age_group`` and
        ``strat_key`` helper columns.

    Raises:
        ValueError: If a ratio is outside (0, 1) or the ratios do not sum to 1.
            ``train_test_split`` also raises ValueError when a stratum is too small.
    """
    ratios = {'train_ratio': train_ratio, 'val_ratio': val_ratio, 'test_ratio': test_ratio}
    invalid = [name for name, value in ratios.items() if not 0 < value < 1]
    if invalid:
        raise ValueError(f"Split ratios must lie in (0, 1); invalid: {invalid}")
    total = sum(ratios.values())
    if not np.isclose(total, 1.0):
        raise ValueError(f"Split ratios must sum to 1, got {total:.4f}")

    if logger:
        logger.log("Creating stratified train/val/test splits...")

    # Stratification key: event status x age quintile
    df = df.copy()
    df['age_group'] = pd.qcut(df['age'], q=5, labels=False, duplicates='drop')
    df['strat_key'] = df['event'].astype(str) + '_' + df['age_group'].astype(str)

    # First split: train vs (val + test)
    train_df, temp_df = train_test_split(
        df,
        test_size=(val_ratio + test_ratio),
        stratify=df['strat_key'],
        random_state=random_state
    )

    # Second split: val vs test, with the ratio rescaled to the held-out subset
    val_size_adj = val_ratio / (val_ratio + test_ratio)
    val_df, test_df = train_test_split(
        temp_df,
        test_size=(1 - val_size_adj),
        stratify=temp_df['strat_key'],
        random_state=random_state
    )

    # Drop helper columns by building new frames, not in place on split slices
    helper_cols = ['age_group', 'strat_key']
    train_df, val_df, test_df = (
        part.drop(columns=helper_cols, errors='ignore')
        for part in (train_df, val_df, test_df)
    )

    if logger:
        logger.log(f"Split sizes - Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
        logger.log(f"Events - Train: {train_df['event'].sum()}, Val: {val_df['event'].sum()}, Test: {test_df['event'].sum()}")

    return train_df, val_df, test_df


def get_feature_columns(df: pd.DataFrame) -> List[str]:
    """Return the numeric model-input columns of *df*, preserving column order.

    Identifiers, mortality-linkage fields, survival targets and the temporary
    split helpers are never features. UCOD_LEADING (cause of death) is an
    outcome and is excluded to avoid target leakage.
    """
    non_features = {
        'SEQN', 'cycle',                                    # identifiers
        'ELIGSTAT', 'MORTSTAT', 'PERMTH_EXM', 'PERMTH_INT',  # mortality linkage
        'UCOD_LEADING',                                     # outcome, not a feature
        'time_years', 'event',                              # survival targets
        'age_group', 'strat_key',                           # split helpers
    }
    is_numeric = pd.api.types.is_numeric_dtype
    return [name for name in df.columns
            if name not in non_features and is_numeric(df[name])]


def _winsorize_with_train_bounds(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    test_df: pd.DataFrame,
    columns: List[str],
    logger: TrainingLogger,
    lower_pct: float = 1,
    upper_pct: float = 99,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Clip *columns* in all three splits to percentile bounds learned from *train_df* only."""
    bounds = {
        col: (np.nanpercentile(train_df[col], lower_pct),
              np.nanpercentile(train_df[col], upper_pct))
        for col in columns
        if col in train_df.columns
    }

    def _clip(split: pd.DataFrame) -> pd.DataFrame:
        split = split.copy()
        for col, (lower, upper) in bounds.items():
            if col in split.columns:
                split[col] = split[col].clip(lower, upper)
        return split

    logger.log(f"Winsorized {len(bounds)} features "
               f"(p{lower_pct}/p{upper_pct} bounds from training split)")
    return _clip(train_df), _clip(val_df), _clip(test_df)


def preprocess_nhanes_data(
    data_dir: str,
    output_dir: str,
    cycles: List[str] = None,
    min_sample_size: int = 5000,
    save_format: str = 'parquet'
) -> Dict[str, Any]:
    """
    Main preprocessing pipeline for NHANES data.

    Order: load -> merge on SEQN -> survival outcomes -> feature engineering ->
    drop rows missing critical fields -> stratified split -> winsorize /
    log-transform / impute / standardize. Every data-dependent statistic
    (winsorization bounds, imputation medians, scaler mean/std) is fit on the
    training split only and then applied unchanged to val and test. This
    prevents held-out data from leaking into preprocessing.

    Args:
        data_dir: Path to NHANES data directory
        output_dir: Path for output files (created if missing)
        cycles: List of cycle codes to include (default: all of E-J)
        min_sample_size: Minimum training set size (a warning is logged if not met)
        save_format: 'parquet' writes parquet; any other value writes CSV

    Returns:
        Dictionary with preprocessing metadata (also saved as JSON)
    """
    if cycles is None:
        cycles = ['E', 'F', 'G', 'H', 'I', 'J']

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    logger = TrainingLogger(str(output_path / 'preprocessing_log.txt'))

    logger.log("=" * 60)
    logger.log("NHANES DATA PREPROCESSING PIPELINE")
    logger.log(f"Cycles: {cycles}")
    logger.log("=" * 60)

    loader = NHANESDataLoader(data_dir, logger)

    # Step 1: Load all data
    logger.log("\n--- STEP 1: Loading Data ---")
    df_demo = loader.load_demographics(cycles)
    df_biopro = loader.load_biochemistry(cycles)
    df_cbc = loader.load_cbc(cycles)
    df_bmx = loader.load_body_measures(cycles)
    df_bpx = loader.load_blood_pressure(cycles)
    df_crp = loader.load_crp(cycles)
    df_mort = loader.load_mortality_data(cycles)

    # Step 2: Merge all data on SEQN (left joins onto demographics)
    logger.log("\n--- STEP 2: Merging Datasets ---")
    df = df_demo.copy()

    merge_order = [
        (df_biopro, 'Biochemistry'),
        (df_cbc, 'CBC'),
        (df_bmx, 'Body measures'),
        (df_bpx, 'Blood pressure'),
        (df_crp, 'CRP'),
    ]

    for df_other, name in merge_order:
        if not df_other.empty:
            # Keep SEQN plus columns not already present (e.g. drops the repeated 'cycle')
            cols_to_keep = [c for c in df_other.columns if c not in df.columns or c == 'SEQN']
            df = df.merge(df_other[cols_to_keep], on='SEQN', how='left')
            logger.log(f"Merged {name}: {len(df)} rows")

    logger.log(f"Combined dataset: {len(df)} rows, {len(df.columns)} columns")

    # Step 3: Merge mortality and construct survival outcomes
    logger.log("\n--- STEP 3: Constructing Survival Outcomes ---")
    df = construct_survival_outcomes(df, df_mort, logger)

    # Step 4: Feature engineering (row-wise only, so safe before the split)
    logger.log("\n--- STEP 4: Feature Engineering ---")
    engineer = NHANESFeatureEngineer(logger)
    df = engineer.compute_all_features(df)

    # Step 5: Drop incomplete rows, then split BEFORE fitting any statistics
    logger.log("\n--- STEP 5: Creating Splits ---")
    # Must run before imputation. Otherwise age/sex would already be
    # median-filled and this filter could never remove a row for them.
    critical_cols = ['SEQN', 'time_years', 'event', 'age', 'sex']
    df = df.dropna(subset=critical_cols)
    logger.log(f"After dropping critical NaN: {len(df)} rows")

    train_df, val_df, test_df = create_stratified_splits(
        df, train_ratio=0.70, val_ratio=0.15, test_ratio=0.15, logger=logger
    )

    if len(train_df) < min_sample_size:
        logger.log(f"Warning: Training set ({len(train_df)}) < minimum ({min_sample_size})", "WARNING")

    # Step 6: Data cleaning, with every statistic learned from the training split
    logger.log("\n--- STEP 6: Data Cleaning (fit on training split) ---")
    cleaner = NHANESDataCleaner(logger)

    feature_cols = get_feature_columns(train_df)
    logger.log(f"Total feature columns: {len(feature_cols)}")

    # Winsorize continuous features with training-split percentile bounds
    continuous_cols = [c for c in feature_cols if train_df[c].nunique() > 20]
    train_df, val_df, test_df = _winsorize_with_train_bounds(
        train_df, val_df, test_df, continuous_cols, logger
    )

    # Log transforms are stateless, so applying them per split leaks nothing
    log_cols = ['CRP', 'LBXSTR', 'LBXRDW', 'NLR', 'LBXSGTSI']
    train_df = cleaner.log_transform_features(train_df, log_cols)
    val_df = cleaner.log_transform_features(val_df, log_cols)
    test_df = cleaner.log_transform_features(test_df, log_cols)

    # Median imputation: learn on train, reuse on val/test
    feature_cols = get_feature_columns(train_df)
    train_df = cleaner.handle_missing_values(train_df, feature_cols, fit=True)
    val_df = cleaner.handle_missing_values(val_df, feature_cols, fit=False)
    test_df = cleaner.handle_missing_values(test_df, feature_cols, fit=False)

    # Step 7: Standardize features (fit on training set only)
    logger.log("\n--- STEP 7: Standardizing Features ---")
    feature_cols = get_feature_columns(train_df)

    cleaner_std = NHANESDataCleaner(logger)
    train_df = cleaner_std.standardize_features(train_df, feature_cols, fit=True)
    val_df = cleaner_std.standardize_features(val_df, feature_cols, fit=False)
    test_df = cleaner_std.standardize_features(test_df, feature_cols, fit=False)

    # Step 8: Save datasets
    logger.log("\n--- STEP 8: Saving Datasets ---")

    if save_format == 'parquet':
        train_df.to_parquet(output_path / 'train.parquet', index=False)
        val_df.to_parquet(output_path / 'val.parquet', index=False)
        test_df.to_parquet(output_path / 'test.parquet', index=False)
    else:
        train_df.to_csv(output_path / 'train.csv', index=False)
        val_df.to_csv(output_path / 'val.csv', index=False)
        test_df.to_csv(output_path / 'test.csv', index=False)

    logger.log(f"Saved datasets in {save_format} format")

    np.save(output_path / 'train_seqn.npy', train_df['SEQN'].values)
    np.save(output_path / 'val_seqn.npy', val_df['SEQN'].values)
    np.save(output_path / 'test_seqn.npy', test_df['SEQN'].values)
    logger.log("Saved SEQN arrays")

    # Step 9: Generate metadata
    logger.log("\n--- STEP 9: Generating Metadata ---")

    metadata = {
        'total_samples': len(df),
        'train_samples': len(train_df),
        'val_samples': len(val_df),
        'test_samples': len(test_df),
        'total_events': int(df['event'].sum()),
        'train_events': int(train_df['event'].sum()),
        'val_events': int(val_df['event'].sum()),
        'test_events': int(test_df['event'].sum()),
        'n_features': len(feature_cols),
        'cycles_included': cycles,
        'preprocessing_date': datetime.now().isoformat(),
        'feature_names': feature_cols,
        'min_sample_size_requirement': min_sample_size,
        'save_format': save_format
    }

    with open(output_path / 'preprocessing_metadata.json', 'w') as f:
        json.dump(metadata, f, indent=2)

    logger.log("Saved preprocessing metadata")

    # Per-feature summary on the standardized training split
    feature_docs = {}
    for col in feature_cols:
        if col in train_df.columns:
            is_numeric = pd.api.types.is_numeric_dtype(train_df[col])
            feature_docs[col] = {
                'dtype': str(train_df[col].dtype),
                'mean': float(train_df[col].mean()) if is_numeric else None,
                'std': float(train_df[col].std()) if is_numeric else None,
                'missing_pct_original': 'N/A'  # Already imputed
            }

    with open(output_path / 'feature_documentation.json', 'w') as f:
        json.dump(feature_docs, f, indent=2)

    logger.log("Saved feature documentation")

    # Final summary
    logger.log("\n" + "=" * 60)
    logger.log("PREPROCESSING COMPLETE")
    logger.log("=" * 60)
    logger.log(f"Train: {len(train_df)} samples, {train_df['event'].sum()} events ({train_df['event'].mean()*100:.1f}%)")
    logger.log(f"Val: {len(val_df)} samples, {val_df['event'].sum()} events ({val_df['event'].mean()*100:.1f}%)")
    logger.log(f"Test: {len(test_df)} samples, {test_df['event'].sum()} events ({test_df['event'].mean()*100:.1f}%)")
    logger.log(f"Features: {len(feature_cols)}")
    logger.log(f"Output directory: {output_path}")

    logger.save()

    return metadata


def main():
    """Run the full pipeline on data next to this script; write to ../NHANES_processed."""
    script_dir = Path(__file__).parent
    output_dir = script_dir.parent / 'NHANES_processed'

    metadata = preprocess_nhanes_data(
        data_dir=str(script_dir),
        output_dir=str(output_dir),
        cycles=['E', 'F', 'G', 'H', 'I', 'J'],
        min_sample_size=5000,
        save_format='parquet',
    )

    summary_lines = (
        f"\nPreprocessing complete. Output saved to: {output_dir}",
        f"Total samples: {metadata['total_samples']}",
        f"Features: {metadata['n_features']}",
    )
    for line in summary_lines:
        print(line)


if __name__ == '__main__':
    main()
