"""
Treatment Cohort Construction Module
Builds vancomycin vs vancomycin+piperacillin/tazobactam cohort from MIMIC-IV prescriptions
"""

import logging
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np
import pandas as pd

from .data_utils import DataUtils


class CohortBuilder:
    """Build treatment cohorts from prescription data.

    The cohort has one row per (subject_id, hadm_id) with at least one
    vancomycin prescription. ``index_time`` is the earliest vancomycin start
    in that admission and ``vpt_flag`` marks admissions whose earliest
    piperacillin-tazobactam start falls inside the configured window after
    ``index_time``.

    The ``config`` object is read for ``READ_KW`` (extra ``read_csv`` keyword
    arguments), ``RX_VANCO_SUB`` (literal substring), ``RX_PTZ`` (regular
    expression) and ``VPT_WINDOW_HOURS``.
    """

    # Column layouts of the per-drug prescription subsets.
    _ID_COLS = ["subject_id", "hadm_id"]
    _VANCO_COLS = _ID_COLS + ["v_start", "v_stop"]
    _PTZ_COLS = _ID_COLS + ["p_start", "p_stop"]

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)

    def build_cohort(self, hosp_path: Path) -> pd.DataFrame:
        """
        Build cohort of patients receiving vancomycin with/without piperacillin-tazobactam

        Args:
            hosp_path: Path to MIMIC-IV hospital data directory
                (must contain ``prescriptions.csv.gz``)

        Returns:
            DataFrame with columns: subject_id, hadm_id, index_time, ptz_time,
            vpt_flag

        Raises:
            ValueError: If the final cohort fails validation
                (see ``_validate_cohort``).
        """
        self.logger.info("Building treatment cohort from prescription data...")

        # Only these columns are loaded from the prescriptions table
        use_cols = ["subject_id", "hadm_id", "starttime", "stoptime", "drug"]

        vanco_chunks, ptz_chunks = [], []
        total_processed = 0

        # Stream the file in chunks to bound memory; the context manager
        # guarantees the underlying file handle is closed even when a chunk
        # fails to process.
        with pd.read_csv(
            hosp_path / "prescriptions.csv.gz",
            usecols=lambda c: c in use_cols,
            chunksize=500_000,
            **self.config.READ_KW
        ) as prescription_reader:
            for chunk_idx, chunk in enumerate(prescription_reader):
                # Project helper; presumably coerces values to datetimes
                chunk["starttime"] = DataUtils.to_datetime_safe(chunk["starttime"])
                chunk["stoptime"] = DataUtils.to_datetime_safe(chunk["stoptime"])

                # Case-insensitive matching on the drug name
                drug_lower = chunk["drug"].astype("string").str.lower()

                # Vancomycin is matched as a literal substring
                vanco_mask = drug_lower.str.contains(
                    self.config.RX_VANCO_SUB, na=False, regex=False
                )
                # Piperacillin-tazobactam is matched as a regular expression
                ptz_mask = drug_lower.str.contains(
                    self.config.RX_PTZ, na=False, regex=True
                )

                vanco_chunk = self._select_drug_rows(chunk, vanco_mask, "v")
                ptz_chunk = self._select_drug_rows(chunk, ptz_mask, "p")

                # Keep only non-empty subsets
                if not vanco_chunk.empty:
                    vanco_chunks.append(vanco_chunk)
                if not ptz_chunk.empty:
                    ptz_chunks.append(ptz_chunk)

                total_processed += len(chunk)
                if chunk_idx % 10 == 0:
                    self.logger.info(
                        f"Processed {total_processed:,} prescription records..."
                    )

        # Combine all chunks, giving each drug its own empty-frame layout
        vanco_df = self._combine_chunks(
            vanco_chunks, "vancomycin", columns=self._VANCO_COLS
        )
        ptz_df = self._combine_chunks(
            ptz_chunks, "piperacillin-tazobactam", columns=self._PTZ_COLS
        )

        cohort = self._create_cohort_with_index_times(vanco_df, ptz_df)
        cohort = self._apply_vpt_logic(cohort)
        cohort = self._clean_cohort_data(cohort)

        # mean() of an empty column is NaN; report 0.0% instead of "nan%"
        vpt_pct = cohort["vpt_flag"].mean() * 100 if len(cohort) else 0.0
        self.logger.info(
            f"Cohort built: N={len(cohort):,}, "
            f"VPT={int(cohort['vpt_flag'].sum()):,} "
            f"({vpt_pct:.1f}%)"
        )

        return cohort

    @classmethod
    def _select_drug_rows(cls, chunk: pd.DataFrame, mask: pd.Series,
                          prefix: str) -> pd.DataFrame:
        """Return the id and start/stop columns of the rows selected by ``mask``.

        ``starttime``/``stoptime`` are renamed to ``<prefix>_start`` /
        ``<prefix>_stop``.
        """
        start_col, stop_col = f"{prefix}_start", f"{prefix}_stop"
        selected = chunk[mask].rename(
            columns={"starttime": start_col, "stoptime": stop_col}
        )
        return selected[cls._ID_COLS + [start_col, stop_col]]

    def _combine_chunks(self, chunks: List[pd.DataFrame], drug_name: str,
                        columns: Optional[List[str]] = None) -> pd.DataFrame:
        """Combine prescription chunks for a specific drug.

        Args:
            chunks: Per-chunk prescription subsets (may be empty).
            drug_name: Human-readable drug name used in log messages.
            columns: Column names for the empty frame returned when ``chunks``
                is empty. Defaults to the vancomycin layout
                (subject_id, hadm_id, v_start, v_stop).

        Returns:
            The concatenated chunks with a fresh index, or an empty DataFrame
            with ``columns`` when there is nothing to combine.
        """
        if not chunks:
            self.logger.warning(f"No {drug_name} prescriptions found")
            empty_cols = self._VANCO_COLS if columns is None else columns
            return pd.DataFrame(columns=list(empty_cols))

        combined = pd.concat(chunks, ignore_index=True)
        self.logger.info(f"Found {len(combined):,} {drug_name} prescriptions")
        return combined

    @classmethod
    def _first_start_per_admission(cls, rx: pd.DataFrame, time_col: str,
                                   out_col: str) -> pd.DataFrame:
        """Return the earliest ``time_col`` per (subject_id, hadm_id) as ``out_col``.

        An empty input yields an empty frame with the id columns and
        ``out_col``; ``time_col`` is not required to exist in that case.
        """
        if rx.empty:
            return pd.DataFrame(columns=cls._ID_COLS + [out_col])
        # The grouped minimum skips missing values, so it equals
        # "sort by time, take the first non-null" without a full sort.
        return (
            rx.groupby(cls._ID_COLS, as_index=False)
            .agg(**{out_col: (time_col, "min")})
        )

    def _create_cohort_with_index_times(self, vanco_df: pd.DataFrame,
                                       ptz_df: pd.DataFrame) -> pd.DataFrame:
        """Create cohort with vancomycin index times and piperacillin-tazobactam times.

        Args:
            vanco_df: Vancomycin prescriptions with a ``v_start`` column.
            ptz_df: Piperacillin-tazobactam prescriptions with a ``p_start``
                column (may be empty).

        Returns:
            One row per admission with vancomycin, with columns subject_id,
            hadm_id, index_time (earliest vancomycin start) and ptz_time
            (earliest piperacillin-tazobactam start, missing when none).
        """
        # Without vancomycin there is no index event and therefore no cohort
        if len(vanco_df) == 0:
            return pd.DataFrame(columns=self._ID_COLS + ["index_time", "ptz_time"])

        vanco_first = self._first_start_per_admission(vanco_df, "v_start", "index_time")
        ptz_first = self._first_start_per_admission(ptz_df, "p_start", "ptz_time")

        # Left join keeps every vancomycin admission; ptz_time is always
        # present afterwards, missing where no PTZ was prescribed.
        cohort = vanco_first.merge(ptz_first, on=self._ID_COLS, how="left")

        # Ensure proper datetime formatting
        cohort["index_time"] = DataUtils.to_datetime_safe(cohort["index_time"])
        cohort["ptz_time"] = DataUtils.to_datetime_safe(cohort["ptz_time"])

        return cohort

    def _apply_vpt_logic(self, cohort: pd.DataFrame) -> pd.DataFrame:
        """Apply VPT combination therapy logic.

        Adds ``vpt_flag`` (1/0): 1 when ``ptz_time`` is present and lies in
        the closed interval [index_time, index_time + VPT_WINDOW_HOURS].

        Returns a new DataFrame; the input frame is left unmodified.
        """
        # NOTE(review): ptz_time is the earliest PTZ start of the admission, so
        # an admission with PTZ before vancomycin and again inside the window
        # is flagged 0 - confirm this matches the study definition.
        time_window = pd.Timedelta(hours=self.config.VPT_WINDOW_HOURS)

        # Work on a copy so the caller's frame does not gain a column
        flagged = cohort.copy()
        window_end = flagged["index_time"] + time_window

        flagged["vpt_flag"] = (
            flagged["ptz_time"].notna() &
            (flagged["ptz_time"] >= flagged["index_time"]) &
            (flagged["ptz_time"] <= window_end)
        ).astype(int)

        return flagged

    def _clean_cohort_data(self, cohort: pd.DataFrame) -> pd.DataFrame:
        """Final data cleaning and validation.

        Drops rows without ``index_time``, converts the id columns to the
        nullable ``Int64`` dtype and validates the result.

        Raises:
            ValueError: If validation fails (see ``_validate_cohort``).
        """
        # Remove entries without valid index time
        initial_count = len(cohort)
        cohort = cohort.dropna(subset=["index_time"]).reset_index(drop=True)

        removed = initial_count - len(cohort)
        if removed > 0:
            self.logger.info(
                f"Removed {removed:,} entries without valid index time"
            )

        # Convert ID columns to proper integer type
        for col in self._ID_COLS:
            cohort[col] = cohort[col].astype("Int64")

        # Validate cohort
        self._validate_cohort(cohort)

        return cohort

    def _validate_cohort(self, cohort: pd.DataFrame) -> None:
        """Validate cohort data quality.

        Raises:
            ValueError: If a required column is absent, if subject_id,
                hadm_id or index_time contain missing values, or if
                vpt_flag holds anything other than 0 or 1.

        A small cohort (< 1000 rows) or a VPT rate outside [1%, 50%] only
        triggers a warning.
        """
        # Check for required columns
        required_cols = ["subject_id", "hadm_id", "index_time", "vpt_flag"]
        missing_cols = set(required_cols) - set(cohort.columns)
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        # Check for missing values in critical columns
        for col in ["subject_id", "hadm_id", "index_time"]:
            if cohort[col].isna().any():
                raise ValueError(f"Missing values found in critical column: {col}")

        # Check VPT flag values
        if not cohort["vpt_flag"].isin([0, 1]).all():
            raise ValueError("VPT flag contains invalid values (must be 0 or 1)")

        # Check for reasonable cohort size
        if len(cohort) < 1000:
            self.logger.warning(f"Cohort size appears small: {len(cohort):,} patients")

        # Check VPT proportion; undefined (NaN) for an empty cohort, so skip it
        if len(cohort) > 0:
            vpt_rate = cohort["vpt_flag"].mean()
            if vpt_rate < 0.01 or vpt_rate > 0.5:
                self.logger.warning(f"Unusual VPT rate: {vpt_rate*100:.1f}%")

        self.logger.info("Cohort validation completed successfully")

    def get_cohort_summary(self, cohort: pd.DataFrame) -> dict:
        """Generate cohort summary statistics.

        Returns:
            Dict with total_patients (number of cohort rows), vpt_patients,
            control_patients, vpt_rate (mean of vpt_flag; NaN for an empty
            cohort), unique_subjects, unique_admissions and date_range
            (earliest/latest index_time).
        """
        n_rows = len(cohort)
        vpt_count = int(cohort["vpt_flag"].sum())

        return {
            "total_patients": n_rows,
            "vpt_patients": vpt_count,
            "control_patients": n_rows - vpt_count,
            "vpt_rate": cohort["vpt_flag"].mean(),
            "unique_subjects": cohort["subject_id"].nunique(),
            "unique_admissions": cohort["hadm_id"].nunique(),
            "date_range": {
                "earliest": cohort["index_time"].min(),
                "latest": cohort["index_time"].max(),
            },
        }
