"""
Causal Inference Module
Implements propensity score methods and Cox regression for treatment effect estimation
"""

import pandas as pd
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict, Any
import logging
from sklearn.linear_model import LogisticRegression
from scipy.stats import ks_2samp
from .data_utils import DataUtils

try:
    from lifelines import CoxPHFitter
except ImportError:
    CoxPHFitter = None


class CausalInference:
    """Propensity score and causal inference methods"""
    
    def __init__(self, config):
        """
        Keep the analysis configuration and create the module logger.

        Args:
            config: Configuration object. Methods of this class read
                attributes such as READ_KW, RX_EMERG_SUB, CONFOUNDERS,
                RANDOM_STATE, PS_CLIP, W_TRIM and COX_PENALIZER from it.
        """
        self.logger = logging.getLogger(__name__)
        self.config = config

        # CoxPHFitter is set to None at import time when lifelines is absent;
        # warn once here instead of failing later at fit time.
        lifelines_missing = CoxPHFitter is None
        if lifelines_missing:
            self.logger.warning("lifelines not installed - Cox regression will not be available")
    
    def build_covariates(self, hosp_path: Path, aki_data: pd.DataFrame, 
                        note_features: pd.DataFrame) -> pd.DataFrame:
        """
        Assemble the covariate matrix used by the causal analysis.

        The AKI cohort is left-joined to admissions (on subject_id/hadm_id),
        patient demographics (on subject_id) and the note-derived features
        (on subject_id/hadm_id). Derived variables are then added and missing
        confounder values filled.

        Note: these are left joins, so a right-hand table with repeated keys
        would duplicate cohort rows.

        Args:
            hosp_path: Directory containing ``patients.csv.gz`` and
                ``admissions.csv.gz`` (MIMIC-IV hosp module)
            aki_data: Cohort frame with AKI labels; must contain
                ``subject_id`` and ``hadm_id``
            note_features: LLM-derived confounders keyed by ``subject_id``
                and ``hadm_id``

        Returns:
            The merged frame with ``age``, ``sexM`` and ``is_emerg`` added and
            the configured confounders filled
        """
        self.logger.info("Building covariate matrix...")
        read_kwargs = self.config.READ_KW

        # Patient-level demographics
        demographics = pd.read_csv(
            hosp_path / "patients.csv.gz",
            usecols=["subject_id", "gender", "anchor_age"],
            **read_kwargs
        )

        # Admission-level information
        time_columns = ["admittime", "dischtime"]
        admissions = pd.read_csv(
            hosp_path / "admissions.csv.gz",
            usecols=["subject_id", "hadm_id", "admittime", "dischtime", "admission_type"],
            parse_dates=time_columns,
            **read_kwargs
        )
        # Pass the timestamps through the shared helper as well; presumably it
        # coerces anything parse_dates could not convert (confirm in DataUtils).
        for column in time_columns:
            admissions[column] = DataUtils.to_datetime_safe(admissions[column])

        visit_keys = ["subject_id", "hadm_id"]
        merged = aki_data.merge(admissions, on=visit_keys, how="left")
        merged = merged.merge(demographics, on="subject_id", how="left")
        merged = merged.merge(note_features, on=visit_keys, how="left")

        with_derived = self._create_derived_variables(merged)
        result = self._handle_missing_confounders(with_derived)

        n_rows, n_columns = result.shape
        self.logger.info(f"Covariate matrix built: {n_rows:,} patients, {n_columns} variables")
        return result
    
    def _create_derived_variables(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add derived demographic/admission columns to *df* in place and return it.

        - ``age``: ``anchor_age`` coerced to numeric (unparseable -> NaN)
        - ``sexM``: 1 where ``gender == "M"``, otherwise 0 (missing -> 0)
        - ``is_emerg``: 1 where ``admission_type`` contains
          ``config.RX_EMERG_SUB`` as a literal (non-regex) substring,
          otherwise 0 (missing -> 0)
        """
        df["age"] = pd.to_numeric(df["anchor_age"], errors="coerce")

        is_male = df["gender"].eq("M")
        df["sexM"] = is_male.astype(int)

        admission_kind = df["admission_type"].astype("string")
        emergency_hit = admission_kind.str.contains(
            self.config.RX_EMERG_SUB, regex=False, na=False
        )
        df["is_emerg"] = emergency_hit.astype(int)

        return df
    
    def _handle_missing_confounders(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Fill missing values in the LLM-derived confounder columns with 0.

        Every name in ``config.CONFOUNDERS`` that is a column of *df* is
        filled with 0 and cast to ``int``; names absent from *df* are skipped.
        The frame is modified in place and also returned.

        Args:
            df: Covariate frame (modified in place)

        Returns:
            The same DataFrame object with the confounder columns filled

        Raises:
            ValueError: if a confounder column holds values that cannot be
                cast to ``int``
        """
        # De-duplicate (order-preserving, as elsewhere in this class): a
        # repeated name would otherwise be processed twice, and the second
        # pass - seeing the already-filled column - would record 0 missing.
        present = [c for c in dict.fromkeys(self.config.CONFOUNDERS) if c in df.columns]

        missing_counts: Dict[str, int] = {}
        for confounder in present:
            missing_counts[confounder] = int(df[confounder].isna().sum())
            # Fill missing with 0 (conservative approach: treat "not reported" as absent)
            df[confounder] = df[confounder].fillna(0).astype(int)

        total_missing = sum(missing_counts.values())
        if total_missing > 0:
            self.logger.info(f"Filled {total_missing:,} missing confounder values with 0")
            per_column = {name: n for name, n in missing_counts.items() if n}
            self.logger.debug(f"Missing confounder values by column: {per_column}")

        return df
    
    def fit_propensity_scores(self, data: pd.DataFrame, 
                             covariates: List[str]) -> Tuple[np.ndarray, np.ndarray, float]:
        """
        Fit a logistic propensity score model and compute stabilized IPT weights.

        Covariates are de-duplicated (order preserved) and restricted to
        columns of *data*; requested names that are absent are reported with a
        warning. Values are coerced to numeric and missing entries imputed
        with the column median. A column with no numeric values at all has a
        NaN median, so it cannot be imputed and would make the solver reject
        the input; such columns are dropped with a warning.

        Propensity scores P(vpt_flag = 1 | X) are clipped to
        ``config.PS_CLIP``. Stabilized weights are P(T=1)/ps for treated rows
        and P(T=0)/(1 - ps) for untreated rows, then clipped to the quantiles
        given by ``config.W_TRIM``.

        Args:
            data: Analysis dataset; must contain a 0/1 ``vpt_flag`` column
            covariates: Candidate covariate column names

        Returns:
            Tuple of (propensity_scores, stabilized_weights, effective_sample_size);
            both arrays are aligned row-for-row with *data*

        Raises:
            ValueError: if no usable covariates remain, or if ``vpt_flag`` does
                not consist of exactly the two levels 0 and 1
        """
        # Validate covariates: de-duplicate, keep only existing columns, and
        # surface (rather than silently ignore) requested names that are absent
        requested = list(dict.fromkeys(covariates))
        usable = [c for c in requested if c in data.columns]
        absent = [c for c in requested if c not in data.columns]
        if absent:
            self.logger.warning(f"Ignoring {len(absent)} covariate(s) not found in data: {absent}")
        if not usable:
            raise ValueError("No valid covariates provided for propensity score estimation")

        # Prepare covariate matrix (non-numeric entries become NaN)
        X = data[usable].copy()
        for col in usable:
            X[col] = pd.to_numeric(X[col], errors="coerce")

        # An all-NaN column stays NaN after median imputation and would make
        # the solver reject the whole matrix, so drop it up front.
        empty_cols = [c for c in usable if X[c].isna().all()]
        if empty_cols:
            self.logger.warning(
                f"Dropping {len(empty_cols)} covariate(s) with no numeric values: {empty_cols}"
            )
            X = X.drop(columns=empty_cols)
            if X.shape[1] == 0:
                raise ValueError("No valid covariates provided for propensity score estimation")
        X = X.fillna(X.median(numeric_only=True))

        # Treatment indicator; the weight formula below assumes 0/1 coding
        treatment = data["vpt_flag"].astype(int)
        observed_levels = set(treatment.unique().tolist())
        if observed_levels != {0, 1}:
            raise ValueError(
                "vpt_flag must contain both 0 and 1 (and nothing else) to fit "
                f"propensity scores; found {sorted(observed_levels)}"
            )

        self.logger.info(f"Fitting propensity score model with {X.shape[1]} covariates...")

        ps_model = LogisticRegression(
            max_iter=400,
            solver="lbfgs",
            random_state=self.config.RANDOM_STATE
        )
        ps_model.fit(X, treatment)

        # classes_ is sorted ([0, 1]), so column 1 is P(vpt_flag = 1 | X);
        # clipping keeps the weights below finite
        propensity_scores = np.clip(
            ps_model.predict_proba(X)[:, 1],
            *self.config.PS_CLIP
        )

        # Stabilized inverse probability weights
        treatment_rate = float(treatment.mean())
        is_treated = treatment.to_numpy() == 1
        stabilized_weights = np.where(
            is_treated,
            treatment_rate / propensity_scores,
            (1 - treatment_rate) / (1 - propensity_scores)
        )

        # Trim extreme weights to the configured quantile range
        weight_bounds = np.quantile(stabilized_weights, self.config.W_TRIM)
        stabilized_weights = np.clip(stabilized_weights, *weight_bounds)

        effective_sample_size = DataUtils.calculate_effective_sample_size(stabilized_weights)

        self.logger.info(f"Propensity score fitting complete - ESS: {effective_sample_size:.1f}")

        return propensity_scores, stabilized_weights, effective_sample_size
    
    def cox_regression_analysis(self, data: pd.DataFrame, weights: np.ndarray, 
                               covariates: List[str]) -> Tuple[Tuple[float, float, float], Tuple[float, float, float]]:
        """
        Perform Cox regression with IPTW and doubly robust estimation.

        Both models are fitted on the same rows. Rows whose duration is
        missing or non-numeric, or whose weight is missing, are dropped from
        BOTH models, and the covariate rows are filtered identically. This
        keeps each patient's covariates on that patient's survival record.

        Args:
            data: Analysis dataset with ``duration_days``, ``event_observed``
                and ``vpt_flag`` columns
            weights: Stabilized inverse probability weights, one per row of *data*
            covariates: List of covariate names for doubly robust estimation

        Returns:
            Tuple of ((IPTW_HR, IPTW_LCL, IPTW_UCL), (DR_HR, DR_LCL, DR_UCL))

        Raises:
            RuntimeError: if lifelines is not installed, or no events remain
        """
        if CoxPHFitter is None:
            raise RuntimeError("lifelines is not installed - Cox regression unavailable")

        self.logger.info("Performing Cox regression analysis...")

        # Prepare survival data (shares data's index; weights are positional)
        survival_data = pd.DataFrame({
            "time": pd.to_numeric(data["duration_days"], errors="coerce"),
            "event": data["event_observed"].astype(int),
            "treat": data["vpt_flag"].astype(int),
            "sw": weights
        })

        # Use a positional mask so the rows handed to the doubly robust model
        # are filtered exactly like survival_data. That model pairs survival
        # and covariate rows by position; dropping rows from survival_data
        # alone would shift every later patient onto the wrong covariates.
        complete = survival_data.notna().all(axis=1).to_numpy()
        n_dropped = int((~complete).sum())
        if n_dropped:
            self.logger.warning(
                f"Dropping {n_dropped:,} rows with missing duration/event/treatment/weight from Cox models"
            )
        survival_data = survival_data.loc[complete]
        analysis_data = data.loc[complete]

        if survival_data["event"].sum() == 0:
            raise RuntimeError("No events in time-to-event window - cannot fit Cox model")

        # IPTW Cox model (treatment effect only)
        iptw_results = self._fit_iptw_cox_model(survival_data)

        # Doubly robust Cox model (treatment + covariates), on the same rows
        dr_results = self._fit_doubly_robust_cox_model(survival_data, analysis_data, covariates)

        self.logger.info("Cox regression analysis complete")

        return iptw_results, dr_results
    
    def _fit_iptw_cox_model(self, survival_data: pd.DataFrame) -> Tuple[float, float, float]:
        """
        Fit a weighted Cox model whose only regressor is the treatment.

        Every column of *survival_data* other than ``time``, ``event`` and
        ``sw`` enters the model. As built by ``cox_regression_analysis``, that
        is just ``treat``. ``sw`` is used as case weights, and ``robust=True``
        requests sandwich standard errors, which suit non-integer IPT weights.

        Returns:
            (hazard ratio, lower CL, upper CL) for ``treat``, i.e. the
            exponentiated coefficient and confidence limits
        """
        model = CoxPHFitter(penalizer=self.config.COX_PENALIZER)
        model.fit(
            survival_data,
            duration_col="time",
            event_col="event",
            weights_col="sw",
            robust=True,
        )

        # lifelines reports coefficients and CIs on the log-hazard scale
        log_hr = model.params_["treat"]
        log_lower, log_upper = model.confidence_intervals_.loc["treat"].values
        return (float(np.exp(log_hr)), float(np.exp(log_lower)), float(np.exp(log_upper)))
    
    def _fit_doubly_robust_cox_model(self, survival_data: pd.DataFrame, 
                                   full_data: pd.DataFrame, 
                                   covariates: List[str]) -> Tuple[float, float, float]:
        """
        Fit the weighted Cox model that adjusts for treatment and covariates.

        This combines the IPT weights (``sw``) with covariate adjustment in
        the outcome model. Covariates are:

        - de-duplicated;
        - restricted to columns of *full_data* that do not clash with the
          survival columns;
        - coerced to numeric and median-imputed.

        Columns with no numeric values are dropped, because they cannot be
        imputed.

        Row pairing: when *full_data* and *survival_data* have the same length
        their rows are paired by position. Otherwise *survival_data* is
        treated as a row subset of *full_data* (incomplete survival rows were
        dropped), and covariate rows are selected by index label. Either way,
        each patient's covariates stay on that patient's survival row.

        Args:
            survival_data: Frame with ``time``, ``event``, ``treat``, ``sw``
            full_data: Analysis dataset holding the covariate columns
            covariates: Covariate column names

        Returns:
            (hazard ratio, lower CL, upper CL) for ``treat``

        Raises:
            ValueError: if rows must be matched by label but *full_data*
                has duplicate index labels
        """
        if len(full_data) != len(survival_data):
            if not full_data.index.is_unique:
                raise ValueError(
                    "Cannot align covariates with survival data: rows were dropped "
                    "and the data index has duplicate labels"
                )
            full_data = full_data.loc[survival_data.index]

        # Validate covariates; names equal to survival columns would create
        # duplicate column labels in the concatenated model frame.
        reserved = set(survival_data.columns)
        requested = list(dict.fromkeys(covariates))
        usable = [c for c in requested if c in full_data.columns and c not in reserved]
        skipped = [c for c in requested if c not in usable]
        if skipped:
            self.logger.warning(
                f"Doubly robust model: ignoring covariate(s) {skipped} "
                "(absent or clashing with survival columns)"
            )

        # Prepare covariate data
        X_covs = full_data[usable].copy()
        for col in usable:
            X_covs[col] = pd.to_numeric(X_covs[col], errors="coerce")

        # An all-NaN column would still be NaN after imputation and lifelines
        # rejects NaN inputs, so drop it.
        empty_cols = [c for c in usable if X_covs[c].isna().all()]
        if empty_cols:
            self.logger.warning(
                f"Doubly robust model: dropping covariate(s) with no numeric values: {empty_cols}"
            )
            X_covs = X_covs.drop(columns=empty_cols)
        X_covs = X_covs.fillna(X_covs.median(numeric_only=True))

        # Both frames now hold the same rows in the same order: pair them positionally
        dr_data = pd.concat([
            survival_data.reset_index(drop=True),
            X_covs.reset_index(drop=True)
        ], axis=1)

        cph_dr = CoxPHFitter(penalizer=self.config.COX_PENALIZER)
        cph_dr.fit(
            dr_data,
            duration_col="time",
            event_col="event",
            weights_col="sw",
            robust=True
        )

        # Coefficients/CIs are on the log-hazard scale; exponentiate to HRs
        dr_hr = float(np.exp(cph_dr.params_["treat"]))
        dr_lcl, dr_ucl = np.exp(cph_dr.confidence_intervals_.loc["treat"].values)

        return (dr_hr, float(dr_lcl), float(dr_ucl))
    
    def evaluate_covariate_set(self, data: pd.DataFrame, covariates: List[str], 
                              label: str = "DEFAULT") -> Dict[str, Any]:
        """
        Run the full causal pipeline for one covariate set and summarise it.

        Steps:

        1. Validate the covariates: de-duplicate and keep those present in
           *data*.
        2. Fit propensity scores and stabilized weights.
        3. Measure covariate balance and propensity-score overlap.
        4. Fit the IPTW and doubly robust Cox models.
        5. Compute E-values from the IPTW hazard ratio.

        Args:
            data: Analysis dataset
            covariates: Covariate column names
            label: Name recorded under ``covset`` in the result

        Returns:
            Flat dict of metrics (balance, ESS, KS statistic, hazard ratios
            with confidence limits and widths, E-values), one row's worth
            for a comparison table

        Raises:
            ValueError: if none of *covariates* is a column of *data*
        """
        self.logger.info(f"Evaluating covariate set: {label}")

        valid_covs = [name for name in dict.fromkeys(covariates) if name in data.columns]
        if len(valid_covs) == 0:
            raise ValueError("No valid covariates for evaluation")

        ps, sw, ess = self.fit_propensity_scores(data, valid_covs)
        balance = self._calculate_covariate_balance(data, valid_covs, sw)
        overlap = self._assess_propensity_score_overlap(ps, data["vpt_flag"])

        iptw_est, dr_est = self.cox_regression_analysis(data, sw, valid_covs)
        iptw_hr, iptw_lcl, iptw_ucl = iptw_est
        dr_hr, dr_lcl, dr_ucl = dr_est

        # Sensitivity to unmeasured confounding, based on the IPTW estimate
        evalue_point, evalue_ci = DataUtils.evalue_from_hr(iptw_hr, iptw_lcl, iptw_ucl)

        # Insertion order matters for downstream tabulation; keep it stable
        summary: Dict[str, Any] = {"covset": label, "k_covs": len(valid_covs)}
        summary.update({
            "mean_abs_SMD_before": balance["mean_smd_before"],
            "mean_abs_SMD_after": balance["mean_smd_after"],
            "ESS": ess,
            "KS_PS": overlap["ks_statistic"],
        })
        summary.update({
            "IPTW_HR": iptw_hr,
            "IPTW_LCL": iptw_lcl,
            "IPTW_UCL": iptw_ucl,
            "IPTW_CI_width": iptw_ucl - iptw_lcl,
        })
        summary.update({
            "DR_HR": dr_hr,
            "DR_LCL": dr_lcl,
            "DR_UCL": dr_ucl,
            "DR_CI_width": dr_ucl - dr_lcl,
        })
        summary.update({
            "Evalue_point": evalue_point,
            "Evalue_CI": evalue_ci,
        })
        return summary
    
    def _calculate_covariate_balance(self, data: pd.DataFrame, covariates: List[str], 
                                   weights: np.ndarray) -> Dict[str, float]:
        """
        Summarise covariate balance via mean absolute standardized mean difference.

        Covariates are coerced to numeric and median-imputed. The absolute
        SMD of each one between ``vpt_flag`` groups is then computed twice:
        unweighted, and with *weights*. NaN SMDs are ignored when averaging.

        Returns:
            {"mean_smd_before": ..., "mean_smd_after": ...}
        """
        treat_codes = data["vpt_flag"].astype(int).values

        numeric = data[covariates].copy()
        for name in covariates:
            numeric[name] = pd.to_numeric(numeric[name], errors="coerce")
        numeric = numeric.fillna(numeric.median(numeric_only=True))

        def _mean_abs_smd(w):
            # w=None gives the unweighted (pre-weighting) SMD
            per_covariate = [
                abs(DataUtils.standardized_mean_difference(numeric[name].values, treat_codes, w))
                for name in numeric.columns
            ]
            return float(np.nanmean(per_covariate))

        before = _mean_abs_smd(None)
        after = _mean_abs_smd(weights)
        return {
            "mean_smd_before": before,
            "mean_smd_after": after
        }
    
    def _assess_propensity_score_overlap(self, propensity_scores: np.ndarray, 
                                       treatment: pd.Series) -> Dict[str, float]:
        """
        Assess propensity score overlap between treatment groups.

        Uses the two-sample Kolmogorov-Smirnov statistic between the treated
        (treatment == 1) and control (treatment == 0) score distributions;
        larger values mean less overlap. If ``scipy.stats.ks_2samp`` raises,
        the statistic is approximated on a 200-point grid over [0, 1] and the
        reason is logged.

        Args:
            propensity_scores: Array-like of scores, one per row
            treatment: Array-like (Series or ndarray) of 0/1 treatment flags

        Returns:
            {"ks_statistic": float}
        """
        scores = np.asarray(propensity_scores, dtype=float)
        treatment_array = np.asarray(treatment).astype(int)
        treated_scores = scores[treatment_array == 1]
        control_scores = scores[treatment_array == 0]

        try:
            ks_stat = float(ks_2samp(treated_scores, control_scores).statistic)
        except Exception as exc:
            # Best effort: a failed overlap diagnostic should not abort the
            # evaluation, but record why the exact test could not be used.
            self.logger.warning(
                f"Using fallback KS calculation ({type(exc).__name__}: {exc})"
            )
            grid = np.linspace(0, 1, 200)

            # Empirical CDFs on the grid; max(1, n) guards an empty group
            treated_cdf = np.searchsorted(
                np.sort(treated_scores), grid, side="right"
            ) / max(1, treated_scores.size)
            control_cdf = np.searchsorted(
                np.sort(control_scores), grid, side="right"
            ) / max(1, control_scores.size)

            ks_stat = float(np.max(np.abs(treated_cdf - control_cdf)))

        return {"ks_statistic": ks_stat}

    def get_covariate_overlap(covariates: List[str], data: pd.DataFrame) -> Dict[str, Any]:
        """ Assess overlap in covariate distributions between treatment groups """
        overlap_results = {}
        for covariate in covariates:
            treated = data[data["vpt_flag"] == 1][covariate]
            control = data[data["vpt_flag"] == 0][covariate]
            ks_stat = ks_2samp(treated, control).statistic
            overlap_results[covariate] = {"ks_statistic": ks_stat}
        
        return overlap_results