"""
Clinical Note Processing Module
Extracts LLM-derived clinical confounders from discharge summaries
"""

import pandas as pd
import numpy as np
import json
import re
from pathlib import Path
from typing import Optional, Tuple, Dict, Any
import logging


class NoteProcessor:
    """Process clinical notes and extract LLM-derived confounder features.

    Covers both halves of a round trip through an LLM batch API:

    * ``create_batch_requests`` writes one JSONL request per admission, tagged
      with ``custom_id = "{subject_id}_{hadm_id}"``.
    * ``extract_features_from_batch_outputs`` parses the returned
      ``assistant_text`` JSON into one row of binary flags per admission.

    ``config`` must expose ``CONFOUNDERS`` (the flag names) and, for request
    creation, ``MAX_NOTE_CHARS``, ``LLM_MODEL``, ``LLM_TEMPERATURE`` and
    ``LLM_MAX_TOKENS``.
    """

    # Textual spellings of a flag value that an LLM may emit, compared after
    # ``strip().lower()``. Other non-numeric text is treated as 0 (see
    # _coerce_binary).
    _TRUE_TOKENS = frozenset({"1", "true", "yes", "y", "t"})
    _FALSE_TOKENS = frozenset({"", "0", "false", "no", "n", "f", "none", "null"})

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)

    @property
    def _confounders(self) -> list:
        """Return ``config.CONFOUNDERS`` as a list.

        A list (not a tuple) is required for DataFrame column selection:
        ``df[("a", "b")]`` looks up one column named by the whole tuple.
        """
        return list(self.config.CONFOUNDERS)

    @staticmethod
    def _pct(part: int, whole: int) -> float:
        """Return ``part`` as a percentage of ``whole`` (0.0 when ``whole`` is 0)."""
        return part / whole * 100 if whole else 0.0

    def extract_features_from_batch_outputs(self, batch_outputs: pd.DataFrame) -> pd.DataFrame:
        """
        Extract confounder features from LLM batch API outputs

        Args:
            batch_outputs: DataFrame with LLM batch API results containing:
                - custom_id: Format '{subject_id}_{hadm_id}'
                - assistant_text: JSON string with confounder flags
                - Optional quality filters: http_status, finish_reason, error

        Returns:
            DataFrame with [subject_id, hadm_id] + binary confounder flags,
            one row per admission. An empty input (or one where every row is
            filtered out) yields an empty frame with the same columns.

        Raises:
            ValueError: if a required column is missing, or if the final
                dataset fails validation.
        """
        self.logger.info("Extracting features from LLM batch outputs...")

        # Validate required columns
        required_columns = {"custom_id", "assistant_text"}
        missing_columns = required_columns - set(batch_outputs.columns)
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}")

        df = batch_outputs.copy()
        initial_count = len(df)

        df = self._apply_quality_filters(df)

        # _pct guards the empty-batch case, which would otherwise divide by 0.
        self.logger.info(
            f"After quality filtering: {len(df):,}/{initial_count:,} "
            f"({self._pct(len(df), initial_count):.1f}%) records retained"
        )

        df = self._parse_patient_identifiers(df)
        df = self._extract_json_features(df)
        result = self._create_final_dataset(df)

        self.logger.info(f"Feature extraction complete: {len(result):,} patients")

        return result

    def _apply_quality_filters(self, df: pd.DataFrame) -> pd.DataFrame:
        """Keep only successful completions, using whichever status columns exist.

        Optional columns:
            - ``http_status``: must equal 200.
            - ``finish_reason``: must be ``"stop"`` (case-insensitive).
            - ``error``: must be missing (NaN/None).

        Returns an independent copy, so callers may add columns without
        writing through to a slice.
        """
        filters_applied = []

        if "http_status" in df.columns:
            # to_numeric so statuses stored as strings (e.g. after a CSV
            # round-trip) still match 200; unparseable values become NaN and
            # are dropped.
            status = pd.to_numeric(df["http_status"], errors="coerce")
            df = df[status == 200]
            filters_applied.append("HTTP 200")

        if "finish_reason" in df.columns:
            df = df[df["finish_reason"].astype(str).str.lower() == "stop"]
            filters_applied.append("finish_reason=stop")

        if "error" in df.columns:
            df = df[df["error"].isna()]
            filters_applied.append("no errors")

        if filters_applied:
            self.logger.info(f"Applied quality filters: {', '.join(filters_applied)}")

        return df.copy()

    def _parse_patient_identifiers(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add nullable-integer ``subject_id``/``hadm_id`` columns parsed from ``custom_id``.

        Rows whose ``custom_id`` does not yield BOTH identifiers are dropped,
        and their number is logged as a warning.
        """
        df = df.copy()

        id_pairs = [self._parse_custom_id(cid) for cid in df["custom_id"]]
        df["subject_id"] = pd.array([pair[0] for pair in id_pairs], dtype="Int64")
        df["hadm_id"] = pd.array([pair[1] for pair in id_pairs], dtype="Int64")

        # Require both ids. A row with only hadm_id would survive here and
        # then abort the whole run in _validate_features_dataset.
        valid_ids = df["subject_id"].notna() & df["hadm_id"].notna()
        invalid_count = int((~valid_ids).sum())

        if invalid_count > 0:
            self.logger.warning(f"Removed {invalid_count:,} records with invalid patient IDs")

        return df[valid_ids].copy()

    def _parse_custom_id(self, custom_id: str) -> Tuple[Optional[int], Optional[int]]:
        """
        Parse custom_id format: 'subject_id_hadm_id' -> (subject_id, hadm_id)

        Args:
            custom_id: String in format '{subject_id}_{hadm_id}'

        Returns:
            Tuple of (subject_id, hadm_id), or (None, None) if either part is
            not a plain run of ASCII digits.
        """
        if not isinstance(custom_id, str):
            return (None, None)

        subject_part, separator, hadm_part = custom_id.strip().partition("_")
        if not separator:
            return (None, None)

        # isascii() + isdigit() accepts only 0-9 runs. It rejects empty parts,
        # signs, spaces, a second '_' (which int() would accept as a digit
        # separator) and non-ASCII digits such as '²'.
        if not all(part.isascii() and part.isdigit() for part in (subject_part, hadm_part)):
            return (None, None)

        return (int(subject_part), int(hadm_part))

    @classmethod
    def _coerce_binary(cls, value: Any) -> int:
        """Map one LLM-emitted flag value onto 0 or 1.

        Rules:
            - ``None`` -> 0.
            - Booleans and numbers -> 1 if non-zero, 0 if zero or NaN.
            - Strings: ``"true"``/``"yes"``/``"1"``... -> 1;
              ``"false"``/``"no"``/``"0"``/``""``... -> 0; numeric text is
              parsed as a number.
            - Any other text -> 0, following the prompt's instruction to
              prefer false negatives.
            - Non-numeric, non-string objects fall back to truthiness.
        """
        if value is None:
            return 0

        if isinstance(value, str):
            token = value.strip().lower()
            if token in cls._TRUE_TOKENS:
                return 1
            if token in cls._FALSE_TOKENS:
                return 0
            try:
                value = float(token)
            except ValueError:
                return 0

        try:
            number = float(value)
        except (TypeError, ValueError):
            return int(bool(value))

        if pd.isna(number):
            return 0
        return int(number != 0)

    def _extract_json_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Extract confounder flags from ``assistant_text`` JSON.

        Returns ``[subject_id, hadm_id] + CONFOUNDERS`` aligned to ``df.index``.
        Flags missing from a response, or responses that do not parse to a
        JSON object, default to 0; unparseable responses are counted and
        logged.
        """
        confounders = self._confounders
        features_list = []
        parse_errors = 0

        for text in df["assistant_text"].astype(str):
            json_obj = self._parse_json_safely(text)
            if not json_obj:
                parse_errors += 1

            features_list.append(
                {name: self._coerce_binary(json_obj.get(name, 0)) for name in confounders}
            )

        if parse_errors > 0:
            self.logger.warning(
                f"JSON parse errors: {parse_errors:,}/{len(df):,} "
                f"({self._pct(parse_errors, len(df)):.1f}%) - using default values"
            )

        # columns= keeps every confounder column even when df is empty, so the
        # later aggregation never references a missing column.
        features_df = pd.DataFrame(features_list, index=df.index, columns=confounders).astype(int)

        return pd.concat([df[["subject_id", "hadm_id"]], features_df], axis=1)

    def _parse_json_safely(self, text: str) -> Dict[str, Any]:
        """
        Safely parse a JSON object from LLM output with error handling

        Strips a surrounding markdown code fence, tries to parse the whole
        text, and falls back to the outermost ``{...}`` span.

        Args:
            text: Raw text response from LLM

        Returns:
            The parsed JSON object, or an empty dict if no JSON *object* can be
            recovered. Valid JSON of another type (``null``, lists, numbers) also
            yields ``{}``, so callers can always use ``.get``.
        """
        if not isinstance(text, str) or not text.strip():
            return {}

        text = text.strip()

        # Remove markdown code blocks if present
        if text.startswith("```"):
            text = re.sub(
                r"^```(?:json)?\s*|\s*```$", "", text,
                flags=re.IGNORECASE | re.DOTALL
            ).strip()

        # Try direct JSON parsing first
        try:
            parsed = json.loads(text)
        except (json.JSONDecodeError, TypeError):
            parsed = None
        if isinstance(parsed, dict):
            return parsed

        # Fall back to the outermost {...} span (e.g. prose around the JSON,
        # or an object wrapped in a list).
        json_match = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group(0))
            except (json.JSONDecodeError, TypeError):
                parsed = None
            if isinstance(parsed, dict):
                return parsed

        return {}

    def _create_final_dataset(self, df: pd.DataFrame) -> pd.DataFrame:
        """Collapse to one row per admission, coerce flags to int, and validate.

        An admission may occur in several output rows (duplicate ``custom_id``).
        Its flags are combined with ``max``, i.e. a logical OR over 0/1 values.
        """
        confounders = self._confounders
        id_columns = ["subject_id", "hadm_id"]

        if df.empty:
            # Nothing to aggregate. Skip groupby on an empty frame and keep the
            # already-final column layout.
            result = df[id_columns + confounders].reset_index(drop=True)
        else:
            result = (
                df.groupby(id_columns, dropna=False, as_index=False)
                .agg({col: "max" for col in confounders})
            )

        # Ensure all confounders are binary integers
        for confounder in confounders:
            result[confounder] = result[confounder].fillna(0).astype(int)

        self._validate_features_dataset(result)

        return result

    def _validate_features_dataset(self, df: pd.DataFrame) -> None:
        """Raise ``ValueError`` unless ``df`` is a well-formed feature table.

        Checks:
            - No missing ``subject_id``/``hadm_id``.
            - Every configured confounder column is present.
            - Every confounder value is 0 or 1.
        """
        if df["hadm_id"].isna().any() or df["subject_id"].isna().any():
            raise ValueError("Final dataset contains missing patient IDs")

        confounders = self._confounders

        missing_confounders = set(confounders) - set(df.columns)
        if missing_confounders:
            raise ValueError(f"Missing confounder columns: {missing_confounders}")

        for confounder in confounders:
            unique_vals = df[confounder].unique()
            if not set(unique_vals).issubset({0, 1}):
                raise ValueError(f"Non-binary values found in {confounder}: {unique_vals}")

        self.logger.info("Feature dataset validation successful")

    def get_confounder_summary(self, features_df: pd.DataFrame) -> Dict[str, Any]:
        """Generate summary statistics for extracted confounders.

        Only configured confounders that are present as columns are counted.
        Rates and means are 0.0 for an empty ``features_df``.
        """
        total = len(features_df)
        summary = {
            "total_patients": total,
            "confounder_prevalence": {},
            "confounder_counts": {}
        }

        present = [name for name in self._confounders if name in features_df.columns]

        for confounder in present:
            count = int(features_df[confounder].sum())
            summary["confounder_prevalence"][confounder] = count / total if total else 0.0
            summary["confounder_counts"][confounder] = count

        flags = features_df[present]

        any_confounders = flags.any(axis=1)
        summary["patients_with_any_confounder"] = int(any_confounders.sum())
        summary["patients_with_any_confounder_rate"] = (
            float(any_confounders.mean()) if total else 0.0
        )

        summary["mean_confounders_per_patient"] = (
            float(flags.sum(axis=1).mean()) if total else 0.0
        )

        return summary

    def create_batch_requests(self, notes_df: pd.DataFrame,
                              index_time_map: Dict[Tuple[int, int], Any],
                              output_file: Path) -> int:
        """
        Create batch requests for LLM processing

        One chat-completion request is written per (subject_id, hadm_id)
        group. The group's notes are joined with newlines and truncated to
        ``config.MAX_NOTE_CHARS``. Admissions with no usable index_time are
        skipped, and their count is logged.

        Args:
            notes_df: DataFrame with discharge notes (columns subject_id,
                hadm_id, text)
            index_time_map: Mapping from (subject_id, hadm_id) to index_time.
                Naive times are treated as UTC; tz-aware times are converted
                to UTC.
            output_file: Path to output JSONL file (overwritten)

        Returns:
            Number of requests created
        """
        self.logger.info("Creating batch requests for LLM processing...")

        prompt_template = self._get_prompt_template()

        request_count = 0
        skipped_no_index_time = 0

        with open(output_file, "w", encoding="utf-8") as f:
            for (subject_id, hadm_id), group in notes_df.groupby(["subject_id", "hadm_id"]):
                # Normalise the keys once. Float-typed ID columns would otherwise
                # render custom_id as '123.0_456.0', which _parse_custom_id
                # rejects when the results come back.
                sid, hid = int(subject_id), int(hadm_id)

                index_time = index_time_map.get((sid, hid))
                if pd.isna(index_time):
                    skipped_no_index_time += 1
                    continue

                # dropna(): missing note text would otherwise reach the LLM as
                # the literal string "nan"/"None".
                combined_text = "\n".join(group["text"].dropna().astype(str).tolist())
                excerpt = combined_text[:self.config.MAX_NOTE_CHARS]

                # tz_localize raises on tz-aware input, so convert those instead.
                timestamp = pd.to_datetime(index_time)
                if timestamp.tzinfo is None:
                    timestamp = timestamp.tz_localize("UTC")
                else:
                    timestamp = timestamp.tz_convert("UTC")
                index_time_iso = timestamp.isoformat()

                prompt = prompt_template.format(
                    index_time_iso=index_time_iso,
                    note_text=excerpt
                )

                request = {
                    "custom_id": f"{sid}_{hid}",
                    "method": "POST",
                    "url": "/v1/chat/completions",
                    "body": {
                        "model": self.config.LLM_MODEL,
                        "temperature": self.config.LLM_TEMPERATURE,
                        "max_tokens": self.config.LLM_MAX_TOKENS,
                        "response_format": {"type": "json_object"},
                        "messages": [
                            {"role": "system", "content": "Return strict JSON only."},
                            {"role": "user", "content": prompt}
                        ]
                    }
                }

                f.write(json.dumps(request) + "\n")
                request_count += 1

                if request_count % 1000 == 0:
                    self.logger.info(f"Created {request_count:,} batch requests...")

        if skipped_no_index_time:
            self.logger.warning(
                f"Skipped {skipped_no_index_time:,} admissions with no index_time"
            )

        self.logger.info(f"Created {request_count:,} batch requests in {output_file}")
        return request_count

    def _get_prompt_template(self) -> str:
        """Get the LLM prompt template for confounder extraction.

        ``str.format`` placeholders: ``{index_time_iso}`` and ``{note_text}``.
        Literal braces in the JSON example are escaped as ``{{``/``}}``.
        """

        return """You are assisting a causal inference study analyzing drug-drug interaction effects on acute kidney injury. The exposure of interest is vancomycin combined with piperacillin/tazobactam versus vancomycin monotherapy.

Your ONLY task: read the discharge note and identify **pre-treatment** (pre-admission or at presentation) risk factors that could confound the relationship between antibiotic choice and AKI risk.

CRITICAL TEMPORAL REASONING RULES:
- Consider ONLY information existing **before or at presentation** relative to index_time = {index_time_iso}.
- DO NOT mark conditions/events clearly arising during hospitalization, hospital course, ICU interventions, inpatient treatments, or discharge medications. Those are potential colliders that can bias causal estimates.
- If timing is ambiguous, be conservative and mark 0. Prefer false negatives over false positives.

CONFOUNDER DEFINITIONS:

f_ckd_pre (Chronic Kidney Disease):
- CKD stages 3-5 (eGFR <60 mL/min/1.73m² for >3 months)  
- Baseline creatinine >1.5× normal for >3 months
- Established dialysis dependence or kidney transplant
- Clinical phrases: "chronic renal insufficiency," "baseline kidney disease," "long-standing nephropathy"

f_dm_pre (Diabetes Mellitus):
- Documented diabetes history (Type 1, Type 2, or secondary)
- Home antidiabetic medications (insulin, metformin, sulfonylureas, etc.)
- HbA1c >6.5% on admission or within 3 months prior
- Diabetic complications (retinopathy, neuropathy, nephropathy)

f_hf_pre (Heart Failure):
- Documented heart failure history of any phenotype (HFrEF, HFpEF, acute/chronic)
- LVEF <50% on prior echocardiography (not during current admission)
- Chronic heart failure medications for HF indication
- Clinical context indicating heart failure regardless of EF

f_liver_pre (Liver Disease):
- Chronic liver disease of any etiology (viral, alcoholic, NASH, etc.)
- Elevated hepatic enzymes >3 months prior to admission  
- Documented cirrhosis, portal hypertension, ascites
- End-stage liver disease or liver transplant history

f_nephrotox_pre (Nephrotoxic Drug Exposure):
- Home medications known for nephrotoxicity: NSAIDs, ACE inhibitors/ARBs (for hypertension), aminoglycosides, calcineurin inhibitors
- High-dose loop/thiazide diuretics present before antibiotic initiation
- Exclude: medications started during hospitalization

OUTPUT FORMAT:
Return ONLY a single-line JSON with binary (0/1) values:
{{
  "f_ckd_pre": 0 or 1,
  "f_dm_pre": 0 or 1, 
  "f_hf_pre": 0 or 1,
  "f_liver_pre": 0 or 1,
  "f_nephrotox_pre": 0 or 1
}}

Discharge note:
---
{note_text}
---"""