"""
Main MIMIC Analysis Pipeline
Complete orchestration of the causal inference analysis
"""

import pandas as pd
import numpy as np
from pathlib import Path
from typing import Tuple, Dict, Any
import logging

from .cohort_builder import CohortBuilder
from .lab_processor import LabProcessor
from .survival_analyzer import SurvivalAnalyzer
from .causal_inference import CausalInference
from .note_processor import NoteProcessor
from ..utils.data_utils import DataUtils


class MIMICAnalysis:
    """Main analysis pipeline orchestrator"""
    
    def __init__(self, config):
        """Wire up the pipeline stages and make sure the output folder exists.

        Args:
            config: Settings object passed unchanged to every stage
                constructor. It must expose ``RESULTS_DIR`` with a
                ``mkdir(parents=..., exist_ok=...)`` method.
        """
        self.config = config
        self.logger = logging.getLogger(__name__)

        # One collaborator per pipeline stage, all sharing the same config.
        self.cohort_builder = CohortBuilder(config)
        self.lab_processor = LabProcessor(config)
        self.survival_analyzer = SurvivalAnalyzer(config)
        self.causal_inference = CausalInference(config)
        self.note_processor = NoteProcessor(config)

        # Create the output folder up front; an existing folder is not an error.
        results_dir = self.config.RESULTS_DIR
        results_dir.mkdir(parents=True, exist_ok=True)
    
    def run_full_analysis(self, batch_outputs: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Run complete causal inference analysis pipeline

        Steps, in order: build the treatment cohort, process laboratory data
        and label AKI, build time-to-event outcomes, extract note features
        from the batch outputs, assemble the analytic dataset, and evaluate
        the "BASE" and "BASE+LLM" covariate sets. The results are saved and
        displayed before returning.

        Args:
            batch_outputs: DataFrame with LLM batch API results

        Returns:
            Tuple of (analytic_dataset, performance_results)

        Raises:
            Exception: Any error raised while the pipeline runs is logged with
                its traceback and the label of the step in progress, then
                re-raised unchanged.
        """

        self.logger.info("Starting MIMIC-IV AKI Causal Inference Analysis")
        self.logger.info("=" * 60)

        # Label of the step in progress; refreshed before each step so that a
        # failure can be attributed to it in the error log.
        step = "startup"

        try:
            # Step 1: Build treatment cohort
            step = "Step 1/6: Building treatment cohort"
            self.logger.info(f"{step}...")
            cohort = self.cohort_builder.build_cohort(self.config.HOSP_DIR)
            self.logger.info(f"Cohort built: N={len(cohort)}, VPT={int(cohort['vpt_flag'].sum())}")

            # Step 2: Process laboratory data and label AKI
            step = "Step 2/6: Processing laboratory data"
            self.logger.info(f"{step}...")
            scr_ids = self.lab_processor.load_scr_itemids(self.config.HOSP_DIR)
            scr_timeseries = self.lab_processor.load_scr_timeseries(
                self.config.HOSP_DIR, cohort, scr_ids
            )
            aki_labeled = self.lab_processor.label_aki(scr_timeseries, cohort)
            self.logger.info(f"AKI labeling complete: AKI rate = {aki_labeled['aki'].mean():.3f}")

            # Step 3: Build time-to-event outcomes
            step = "Step 3/6: Building time-to-event outcomes"
            self.logger.info(f"{step}...")
            survival_data = self.survival_analyzer.build_event_times(
                scr_timeseries, aki_labeled, self.config.HOSP_DIR
            )
            self.logger.info(f"Survival data: N={len(survival_data)}, events={int(survival_data['event_observed'].sum())}")

            # Step 4: Extract note features from batch outputs
            step = "Step 4/6: Processing clinical note features"
            self.logger.info(f"{step}...")
            note_features = self.note_processor.extract_features_from_batch_outputs(batch_outputs)
            self.logger.info(f"Note features extracted: {len(note_features)} patients")

            # Step 5: Build final analytic dataset
            step = "Step 5/6: Building analytic dataset"
            self.logger.info(f"{step}...")
            analytic_data = self.causal_inference.build_covariates(
                self.config.HOSP_DIR, aki_labeled, note_features
            )

            # Merge with survival outcomes
            analytic_data = analytic_data.merge(
                survival_data[["subject_id", "hadm_id", "duration_days", "event_observed"]],
                on=["subject_id", "hadm_id"],
                how="left"
            )

            # Remove rows missing treatment or outcome, and report how many
            # were lost so the sample-size change is visible in the log.
            rows_before = len(analytic_data)
            analytic_data = analytic_data.dropna(
                subset=["vpt_flag", "aki", "duration_days", "event_observed"]
            ).reset_index(drop=True)
            rows_dropped = rows_before - len(analytic_data)
            if rows_dropped:
                self.logger.info(f"Dropped {rows_dropped} rows with missing treatment/outcome values")

            self.logger.info(f"Final analytic dataset: N={len(analytic_data)}, "
                           f"Events={int(analytic_data['event_observed'].sum())}, "
                           f"Treated={int(analytic_data['vpt_flag'].sum())}")

            # Step 6: Comparative analysis
            step = "Step 6/6: Running comparative causal inference"
            self.logger.info(f"{step}...")

            # Define covariate sets
            base_covariates = ["age", "sexM", "is_emerg", "baseline"]
            available_confounders = [
                c for c in self.config.CONFOUNDERS if c in analytic_data.columns
            ]
            # dict.fromkeys de-duplicates while keeping order, so a confounder
            # that repeats a base covariate is not passed to the model twice.
            llm_enhanced_covariates = list(
                dict.fromkeys(base_covariates + available_confounders)
            )

            # Evaluate both models
            base_results = self.causal_inference.evaluate_covariate_set(
                analytic_data, base_covariates, "BASE"
            )

            llm_results = self.causal_inference.evaluate_covariate_set(
                analytic_data, llm_enhanced_covariates, "BASE+LLM"
            )

            # Compile results
            performance_df = pd.DataFrame([base_results, llm_results])

            # Save results
            step = "saving results"
            self._save_results(analytic_data, performance_df)

            # Display results
            step = "displaying results"
            self._display_results(performance_df)

            return analytic_data, performance_df

        except Exception as e:
            # logger.exception records the traceback alongside the step label.
            self.logger.exception(f"Analysis failed during '{step}': {e}")
            raise
    
    def _save_results(self, analytic_data: pd.DataFrame, performance_df: pd.DataFrame):
        """Save analysis results to files"""
        
        results_file = self.config.RESULTS_DIR / "causal_inference_results.csv"
        analytic_file = self.config.RESULTS_DIR / "analytic_dataset.csv"
        
        performance_df.to_csv(results_file, index=False)
        analytic_data.to_csv(analytic_file, index=False)
        
        self.logger.info(f"Results saved to:")
        self.logger.info(f"  - {results_file}")
        self.logger.info(f"  - {analytic_file}")
    
    def _display_results(self, performance_df: pd.DataFrame):
        """Display key results in formatted output"""
        
        self.logger.info("\n" + "=" * 60)
        self.logger.info("CAUSAL INFERENCE RESULTS")
        self.logger.info("=" * 60)
        
        display_columns = [
            "covset", "k_covs",
            "mean_abs_SMD_before", "mean_abs_SMD_after", 
            "ESS", "KS_PS",
            "IPTW_HR", "IPTW_LCL", "IPTW_UCL",
            "DR_HR", "DR_LCL", "DR_UCL",
            "Evalue_point", "Evalue_CI"
        ]
        
        # Format numeric columns
        display_df = performance_df.copy()
        for col in display_df.columns:
            if col in ["IPTW_HR", "DR_HR"]:
                display_df[col] = display_df[col].round(3)
            elif col in ["IPTW_LCL", "IPTW_UCL", "DR_LCL", "DR_UCL"]:
                display_df[col] = display_df[col].round(3)
            elif col in ["mean_abs_SMD_before", "mean_abs_SMD_after", "KS_PS"]:
                display_df[col] = display_df[col].round(4)
            elif col in ["ESS"]:
                display_df[col] = display_df[col].round(1)
            elif col in ["Evalue_point", "Evalue_CI"]:
                display_df[col] = display_df[col].round(2)
        
        print(display_df[display_columns].to_string(index=False))
        
        # Summary interpretation
        base_row = display_df[display_df["covset"] == "BASE"].iloc[0]
        llm_row = display_df[display_df["covset"] == "BASE+LLM"].iloc[0]
        
        self.logger.info(f"\nKEY FINDINGS:")
        self.logger.info(f"Base Model (Traditional Covariates):")
        self.logger.info(f"  - Hazard Ratio: {base_row['IPTW_HR']:.3f} "
                        f"(95% CI: {base_row['IPTW_LCL']:.3f}-{base_row['IPTW_UCL']:.3f})")
        self.logger.info(f"  - Mean |SMD| After Weighting: {base_row['mean_abs_SMD_after']:.4f}")
        self.logger.info(f"  - Effective Sample Size: {base_row['ESS']:.1f}")
        
        self.logger.info(f"\nLLM-Enhanced Model (Traditional + Note-Derived Features):")
        self.logger.info(f"  - Hazard Ratio: {llm_row['IPTW_HR']:.3f} "
                        f"(95% CI: {llm_row['IPTW_LCL']:.3f}-{llm_row['IPTW_UCL']:.3f})")
        self.logger.info(f"  - Mean |SMD| After Weighting: {llm_row['mean_abs_SMD_after']:.4f}")
        self.logger.info(f"  - Effective Sample Size: {llm_row['ESS']:.1f}")
        
        # Balance improvement
        balance_improvement = base_row['mean_abs_SMD_after'] - llm_row['mean_abs_SMD_after']
        self.logger.info(f"\nCovariate Balance Improvement: {balance_improvement:.4f}")
        self.logger.info(f"({'Better' if balance_improvement > 0 else 'Worse'} balance with LLM features)")
    
    def generate_summary_report(self, analytic_data: pd.DataFrame, 
                               performance_df: pd.DataFrame) -> Dict[str, Any]:
        """Generate comprehensive summary report"""
        
        report = {
            "study_population": {
                "total_patients": len(analytic_data),
                "vpt_patients": int(analytic_data['vpt_flag'].sum()),
                "control_patients": int((~analytic_data['vpt_flag'].astype(bool)).sum()),
                "total_events": int(analytic_data['event_observed'].sum()),
                "vpt_events": int(analytic_data[analytic_data['vpt_flag']==1]['event_observed'].sum()),
                "control_events": int(analytic_data[analytic_data['vpt_flag']==0]['event_observed'].sum())
            },
            "performance_metrics": performance_df.to_dict('records'),
            "clinical_impact": {
                "absolute_risk_difference": self._calculate_absolute_risk_difference(analytic_data),
                "number_needed_to_harm": self._calculate_nnh(analytic_data)
            },
            "confounders_discovered": {
                confounder: int(analytic_data[confounder].sum()) 
                for confounder in self.config.CONFOUNDERS 
                if confounder in analytic_data.columns
            }
        }
        
        return report
    
    def _calculate_absolute_risk_difference(self, df: pd.DataFrame) -> float:
        """Calculate absolute risk difference between groups"""
        vpt_risk = df[df['vpt_flag']==1]['event_observed'].mean()
        control_risk = df[df['vpt_flag']==0]['event_observed'].mean()
        return float(vpt_risk - control_risk)
    
    def _calculate_nnh(self, df: pd.DataFrame) -> float:
        """Calculate number needed to harm"""
        ard = self._calculate_absolute_risk_difference(df)
        return 1.0 / ard if ard > 0 else float('inf')
