"""
Laboratory Data Processing Module
Processes serum creatinine measurements and applies KDIGO AKI criteria
"""

import pandas as pd
import numpy as np
from pathlib import Path
from typing import List, Dict
import logging
from .data_utils import DataUtils


class LabProcessor:
    """Process laboratory data for AKI detection.

    Pipeline over MIMIC-IV hospital tables:

    1. ``load_scr_itemids`` selects serum creatinine (SCr) item IDs from
       ``d_labitems.csv.gz``.
    2. ``load_scr_timeseries`` streams ``labevents.csv.gz``, keeps those items
       for the cohort's admissions, standardizes values to mg/dL and computes
       hours relative to each admission's ``index_time``.
    3. ``label_aki`` derives a baseline SCr per admission and applies a
       48-hour absolute-rise and a 7-day relative-rise criterion.
    4. ``get_aki_summary`` aggregates the labels into summary statistics.
    """

    # Rows read from labevents.csv.gz per chunk.
    LAB_CHUNKSIZE = 1_000_000

    # Open interval of plausible standardized SCr values (mg/dL); values
    # outside it are discarded as invalid or extreme outliers.
    SCR_MIN_MGDL = 0.0
    SCR_MAX_MGDL = 50.0

    # Time windows (hours relative to index_time) and thresholds for labeling.
    BASELINE_LOOKBACK_HOURS = 24   # baseline = median SCr in [-24h, 0h]
    BASELINE_FALLBACK_HOURS = 24   # otherwise first SCr in [0h, 24h]
    AKI_ABS_WINDOW_HOURS = 48      # absolute-rise window (0h, 48h]
    AKI_ABS_RISE_MGDL = 0.3        # max SCr >= baseline + 0.3 mg/dL
    AKI_REL_WINDOW_HOURS = 24 * 7  # relative-rise window (0h, 7d]
    AKI_REL_RISE_FACTOR = 1.5      # max SCr >= 1.5 x baseline

    # Column layout of the frame returned by ``load_scr_timeseries``.
    _TIMESERIES_COLUMNS = (
        "subject_id", "hadm_id", "itemid", "charttime", "scr_mgdl",
        "index_time", "vpt_flag", "dt",
    )

    def __init__(self, config):
        """Store the configuration object and create a module logger.

        Args:
            config: Object providing ``READ_KW`` (keyword arguments for
                ``pd.read_csv``) and the patterns ``RX_SCR_LABEL``,
                ``RX_SCR_FLUID``, ``RX_MGL`` and ``RX_MGDL``.
        """
        self.config = config
        self.logger = logging.getLogger(__name__)

    def load_scr_itemids(self, hosp_path: Path) -> List[int]:
        """
        Load serum creatinine lab item IDs from MIMIC-IV d_labitems

        Items are selected when their ``label`` contains
        ``config.RX_SCR_LABEL`` and their ``fluid`` matches
        ``config.RX_SCR_FLUID`` (both case-insensitive). If that yields
        nothing, the fluid filter is dropped and a warning is logged.

        Args:
            hosp_path: Path to MIMIC-IV hospital data directory

        Returns:
            List of item IDs corresponding to serum creatinine measurements

        Raises:
            RuntimeError: If no matching item IDs are found even without the
                fluid filter.
        """
        self.logger.info("Loading serum creatinine item IDs...")

        d_lab = pd.read_csv(hosp_path / "d_labitems.csv.gz", **self.config.READ_KW)

        # NOTE(review): RX_SCR_LABEL is matched literally (regex=False) even
        # though its name suggests a regex, unlike RX_SCR_FLUID; confirm
        # against the config.
        label_mask = d_lab["label"].astype("string").str.contains(
            self.config.RX_SCR_LABEL, na=False, regex=False, case=False
        )
        fluid_mask = d_lab["fluid"].astype("string").str.contains(
            self.config.RX_SCR_FLUID, na=False, regex=True, case=False
        )

        ids = self._unique_itemids(d_lab, label_mask & fluid_mask)

        # Fallback: the label match alone (may include non-serum items).
        if not ids:
            self.logger.warning("No creatinine items found with fluid filter, using label only")
            ids = self._unique_itemids(d_lab, label_mask)

        if not ids:
            raise RuntimeError("No serum creatinine item IDs found. Check d_labitems.csv.gz filters.")

        self.logger.info(f"Found {len(ids)} serum creatinine item IDs: {ids[:5]}")
        return ids

    @staticmethod
    def _unique_itemids(d_lab: pd.DataFrame, mask: pd.Series) -> List[int]:
        """Return the distinct non-null ``itemid`` values of rows selected by ``mask``."""
        return d_lab.loc[mask, "itemid"].dropna().astype("Int64").astype(int).unique().tolist()

    def load_scr_timeseries(self, hosp_path: Path, cohort: pd.DataFrame,
                           scr_ids: List[int]) -> pd.DataFrame:
        """
        Load serum creatinine time series data for cohort patients

        ``labevents.csv.gz`` is streamed in chunks of ``LAB_CHUNKSIZE`` rows.
        Only rows whose ``itemid`` is in ``scr_ids`` and whose ``hadm_id`` is
        in the cohort are kept.

        Args:
            hosp_path: Path to MIMIC-IV hospital data directory
            cohort: Treatment cohort DataFrame; must contain ``subject_id``,
                ``hadm_id``, ``index_time`` and ``vpt_flag``
            scr_ids: List of serum creatinine item IDs

        Returns:
            DataFrame with columns ``subject_id``, ``hadm_id``, ``itemid``,
            ``charttime``, ``scr_mgdl``, ``index_time``, ``vpt_flag`` and
            ``dt`` (hours from ``index_time`` to ``charttime``; negative
            before the index time). Empty results use the same columns.

        Raises:
            RuntimeError: If ``scr_ids`` is empty.
        """
        if len(scr_ids) == 0:
            raise RuntimeError("No serum creatinine item IDs provided")

        self.logger.info("Loading serum creatinine time series data...")

        hadm_set = set(cohort["hadm_id"].dropna().astype(int).unique().tolist())
        self.logger.info(f"Processing lab events for {len(hadm_set):,} hospital admissions")

        keep_cols = ["subject_id", "hadm_id", "itemid", "charttime", "valuenum", "valueuom"]
        out_cols = ["subject_id", "hadm_id", "itemid", "charttime", "scr_mgdl"]

        lab_reader = pd.read_csv(
            hosp_path / "labevents.csv.gz",
            usecols=lambda c: c in keep_cols,
            chunksize=self.LAB_CHUNKSIZE,
            low_memory=False
        )

        chunks = []
        rows_read = 0

        for chunk_idx, chunk in enumerate(lab_reader):
            # Count the real number of rows read (the last chunk is usually short).
            rows_read += len(chunk)

            # Creatinine tests for cohort admissions only. ``.copy()`` detaches
            # the filtered frame so the column assignment below does not write
            # into a view (SettingWithCopyWarning).
            keep = chunk["itemid"].isin(scr_ids) & chunk["hadm_id"].isin(hadm_set)
            chunk = chunk.loc[keep].copy()

            if not chunk.empty:
                chunk["charttime"] = DataUtils.to_datetime_safe(chunk["charttime"])
                chunk = self._standardize_creatinine_units(chunk)
                if not chunk.empty:
                    chunks.append(chunk[out_cols])

            # Report progress for every 10th chunk, whether or not it matched.
            if chunk_idx % 10 == 0:
                self.logger.info(f"Processed {rows_read/1_000_000:.1f}M lab events...")

        if not chunks:
            self.logger.warning("No creatinine measurements found for cohort")
            return pd.DataFrame(columns=list(self._TIMESERIES_COLUMNS))

        df = pd.concat(chunks, ignore_index=True)
        self.logger.info(f"Found {len(df):,} creatinine measurements")

        # Attach each admission's index time and treatment flag.
        df = df.merge(
            cohort[["subject_id", "hadm_id", "index_time", "vpt_flag"]],
            on=["subject_id", "hadm_id"],
            how="inner"
        )

        if len(df) == 0:
            self.logger.warning("No overlapping creatinine measurements for cohort")
            # Keep the schema consistent with the non-empty path.
            return df.assign(dt=np.nan)

        df = self._calculate_time_differences(df)
        df = self._clean_timeseries_data(df)

        self.logger.info(f"Final creatinine time series: {len(df):,} measurements")
        return df

    def _standardize_creatinine_units(self, chunk: pd.DataFrame) -> pd.DataFrame:
        """Convert creatinine measurements to mg/dL and drop implausible rows.

        Rows whose lower-cased ``valueuom`` matches ``config.RX_MGL`` are
        divided by 10 (mg/L -> mg/dL). Rows matching ``config.RX_MGDL`` are
        kept as-is. Rows matching neither are dropped. Surviving rows are kept
        only if ``SCR_MIN_MGDL < scr_mgdl < SCR_MAX_MGDL`` (NaN values fail
        this test and are dropped).

        Args:
            chunk: Lab events with ``valuenum`` and ``valueuom`` columns. It is
                not modified.

        Returns:
            A new DataFrame with an added ``scr_mgdl`` column.
        """
        # Work on a private copy: the caller's frame must not gain a column,
        # and it may be a view of a larger frame.
        chunk = chunk.copy()

        unit_str = chunk["valueuom"].astype("string").str.lower()
        mg_l_mask = unit_str.str.contains(self.config.RX_MGL, na=False, regex=True)
        mg_dl_mask = unit_str.str.contains(self.config.RX_MGDL, na=False, regex=True)

        chunk["scr_mgdl"] = chunk["valuenum"]
        chunk.loc[mg_l_mask, "scr_mgdl"] = chunk.loc[mg_l_mask, "valuenum"] / 10.0

        scr = chunk["scr_mgdl"]
        keep = (
            (mg_dl_mask | mg_l_mask)
            & (scr > self.SCR_MIN_MGDL)
            & (scr < self.SCR_MAX_MGDL)
        )
        return chunk.loc[keep].copy()

    def _calculate_time_differences(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add ``dt``, the hours from ``index_time`` to ``charttime``.

        ``index_time`` is parsed with ``DataUtils.to_datetime_safe``. Rows
        with a missing ``charttime`` or ``index_time`` are dropped, and the
        index is reset.
        """
        df["index_time"] = DataUtils.to_datetime_safe(df["index_time"])
        df = df.dropna(subset=["charttime", "index_time"]).reset_index(drop=True)

        df["dt"] = (df["charttime"] - df["index_time"]).dt.total_seconds() / 3600.0
        return df

    def _clean_timeseries_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Sort by admission and time, and cast ID columns to nullable ``Int64``."""
        df = df.sort_values(["subject_id", "hadm_id", "charttime"]).reset_index(drop=True)

        for col in ["subject_id", "hadm_id"]:
            df[col] = df[col].astype("Int64")

        return df

    def label_aki(self, df_scr: pd.DataFrame, cohort: pd.DataFrame) -> pd.DataFrame:
        """
        Label AKI events according to KDIGO criteria

        Args:
            df_scr: Serum creatinine time series data (needs ``subject_id``,
                ``hadm_id``, ``charttime``, ``scr_mgdl`` and ``dt``)
            cohort: Treatment cohort DataFrame

        Returns:
            The cohort rows with ``baseline``, ``aki48``, ``aki7x`` and ``aki``
            added (plus ``max48``/``max7`` when creatinine data is present).
            If ``df_scr`` is empty, ``baseline`` is NaN and every flag is 0.
        """
        self.logger.info("Applying KDIGO AKI criteria...")

        if len(df_scr) == 0:
            self.logger.warning("No creatinine data available - all patients will have AKI=0")
            return self._create_empty_aki_labels(cohort)

        baseline_data = self._calculate_baseline_creatinine(df_scr)
        out = cohort.merge(baseline_data, on=["subject_id", "hadm_id"], how="left")
        out = self._handle_missing_baselines(out, df_scr)
        out = self._apply_kdigo_criteria(out, df_scr)

        aki_rate = out["aki"].mean()
        self.logger.info(f"AKI labeling complete: rate={aki_rate:.3f}, N={len(out):,}")

        return out

    def _create_empty_aki_labels(self, cohort: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``cohort`` with NaN baseline and all AKI flags set to 0."""
        out = cohort.copy()
        out["baseline"] = np.nan
        out["aki48"] = 0
        out["aki7x"] = 0
        out["aki"] = 0
        return out

    def _calculate_baseline_creatinine(self, df_scr: pd.DataFrame) -> pd.DataFrame:
        """Return the median SCr per admission over ``[-BASELINE_LOOKBACK_HOURS, 0]`` hours."""
        in_window = (df_scr["dt"] >= -self.BASELINE_LOOKBACK_HOURS) & (df_scr["dt"] <= 0)
        baseline_window = df_scr[in_window]

        baseline_data = (
            baseline_window.groupby(["subject_id", "hadm_id"], as_index=False)
            .agg(baseline=("scr_mgdl", "median"))
        )

        self.logger.info(f"Calculated baseline creatinine for {len(baseline_data):,} patients")
        return baseline_data

    def _handle_missing_baselines(self, out: pd.DataFrame, df_scr: pd.DataFrame) -> pd.DataFrame:
        """Fill missing baselines with the earliest SCr in ``[0, BASELINE_FALLBACK_HOURS]`` hours.

        Admissions with no measurement in either window keep a NaN baseline,
        and a warning reports how many remain.
        """
        missing_baseline = out["baseline"].isna()
        if not missing_baseline.any():
            return out

        in_window = (df_scr["dt"] >= 0) & (df_scr["dt"] <= self.BASELINE_FALLBACK_HOURS)
        first24_data = (
            df_scr[in_window]
            .sort_values("charttime")
            .groupby(["subject_id", "hadm_id"], as_index=False)
            .agg(first24=("scr_mgdl", "first"))
        )

        out = out.merge(first24_data, on=["subject_id", "hadm_id"], how="left")
        out["baseline"] = out["baseline"].fillna(out["first24"])
        out = out.drop(columns=["first24"])

        remaining_missing = out["baseline"].isna().sum()
        self.logger.info(f"Filled {missing_baseline.sum() - remaining_missing:,} missing baselines")

        if remaining_missing > 0:
            self.logger.warning(f"{remaining_missing:,} patients still missing baseline creatinine")

        return out

    def _apply_kdigo_criteria(self, out: pd.DataFrame, df_scr: pd.DataFrame) -> pd.DataFrame:
        """Apply KDIGO AKI criteria (stage 1+)

        Adds these columns:

        * ``max48`` and ``max7``: maximum post-index SCr in ``(0, 48h]`` and
          ``(0, 7d]``.
        * ``aki48``: set when ``max48 >= baseline + 0.3``.
        * ``aki7x``: set when ``max7 >= 1.5 * baseline``.
        * ``aki``: set when either flag is set.

        Each flag is 0 when its inputs are missing. Rises are measured against
        the single ``baseline`` value, not against every pair of measurements.
        """
        keys = ["subject_id", "hadm_id"]
        post_index = df_scr[df_scr["dt"] > 0]

        post48_data = (
            post_index[post_index["dt"] <= self.AKI_ABS_WINDOW_HOURS]
            .groupby(keys, as_index=False)
            .agg(max48=("scr_mgdl", "max"))
        )
        post7d_data = (
            post_index[post_index["dt"] <= self.AKI_REL_WINDOW_HOURS]
            .groupby(keys, as_index=False)
            .agg(max7=("scr_mgdl", "max"))
        )

        out = (out.merge(post48_data, on=keys, how="left")
                  .merge(post7d_data, on=keys, how="left"))

        has_baseline = out["baseline"].notna()
        out["aki48"] = (
            out["max48"].notna()
            & has_baseline
            & (out["max48"] >= (out["baseline"] + self.AKI_ABS_RISE_MGDL))
        ).astype(int)
        out["aki7x"] = (
            out["max7"].notna()
            & has_baseline
            & (out["max7"] >= (self.AKI_REL_RISE_FACTOR * out["baseline"]))
        ).astype(int)

        out["aki"] = ((out["aki48"] == 1) | (out["aki7x"] == 1)).astype(int)
        return out

    def get_aki_summary(self, aki_data: pd.DataFrame) -> Dict:
        """Summarize AKI labels overall and by treatment group.

        Args:
            aki_data: Output of ``label_aki`` (needs ``aki``, ``aki48``,
                ``aki7x``, ``vpt_flag`` and ``baseline``).

        Returns:
            Dict of counts and rates. Rates are NaN when their denominator is
            empty, including ``aki_rate`` for an empty input.
        """
        total_patients = len(aki_data)
        aki_patients = int(aki_data["aki"].sum())
        aki48_patients = int(aki_data["aki48"].sum())
        aki7x_patients = int(aki_data["aki7x"].sum())

        vpt_aki_rate = aki_data.loc[aki_data["vpt_flag"] == 1, "aki"].mean()
        control_aki_rate = aki_data.loc[aki_data["vpt_flag"] == 0, "aki"].mean()

        baseline_available = int(aki_data["baseline"].notna().sum())

        return {
            "total_patients": total_patients,
            "aki_patients": aki_patients,
            # Avoid ZeroDivisionError on an empty frame. NaN matches the group
            # rates, whose empty-group mean is also NaN.
            "aki_rate": aki_patients / total_patients if total_patients else float("nan"),
            "aki48_patients": aki48_patients,
            "aki7x_patients": aki7x_patients,
            "vpt_aki_rate": vpt_aki_rate,
            "control_aki_rate": control_aki_rate,
            "baseline_available": baseline_available,
            "baseline_missing": total_patients - baseline_available,
        }
