"""
Fixed Realistic Drug Probability Predictor

Uses ICD diagnosis codes from MIMIC data to generate clinically appropriate
drug recommendations via evidence-based mappings.

Authors: Research Team
Date: 2024
"""

import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

from ..core.data_structures import ClinicalContext
from .icd_to_drug_mapper import ICDToDrugMapper, create_icd_to_drug_mapper
from .medical_rules import MedicalRulesEngine, create_medical_rules_engine
from .mimic_loader import MimicDataLoader, load_mimic_data

logger = logging.getLogger(__name__)


@dataclass
class FixedPredictionConfig:
    """Configuration for fixed realistic drug prediction.

    Consumed by ``FixedRealisticDrugPredictor``. Every field has a default,
    so ``FixedPredictionConfig()`` is a complete configuration.
    """
    
    # Data sources
    # Passed as ``data_root`` to ``load_mimic_data`` during ``train()``.
    mimic_data_root: str = "data/processed/mimic"
    # Passed to ``create_medical_rules_engine`` during ``train()``;
    # None means no custom rules file is supplied.
    custom_rules_file: Optional[Path] = None
    
    # ICD mapping parameters
    # When False, None is passed to the ICD mapper instead of the patient's age.
    use_age_adjustments: bool = True
    # When False, None is passed to the ICD mapper instead of the patient's gender.
    use_gender_adjustments: bool = True
    
    # Safety constraints
    # Run the rules engine's contraindication rules over the probabilities.
    apply_contraindications: bool = True
    # Run the rules engine's interaction rules; they are only invoked when a
    # non-empty medication list is available.
    apply_interactions: bool = True
    # Weight w in the safety blend p * ((1 - w) + w * safety_score).
    # A value <= 0 skips the blend entirely.
    safety_weight: float = 0.3
    
    # Output configuration
    # Floor applied to every drug's probability before the final
    # normalization; also the value used for drugs the ICD mapper returns
    # no probability for.
    min_probability: float = 1e-6


class FixedRealisticDrugPredictor:
    """
    Fixed realistic drug probability predictor.

    Uses ICD diagnosis codes from MIMIC data to generate evidence-based
    drug recommendations through clinical guidelines mapping.

    Usage: construct with a ``FixedPredictionConfig``, call ``train()`` once
    to load the ICD mapper / rules engine / MIMIC loader, then call
    ``predict_probabilities()`` or ``predict_top_k()``.
    """

    # Diagnosis names that are recognised without an explicit ICD code.
    # Keys are lower-case with underscores; values are ICD code strings.
    _DIAGNOSIS_TO_ICD: Dict[str, str] = {
        'hypertension': '4019',
        'atrial_fibrillation': '42731',
        'heart_failure': '4280',
        'diabetes': '25000',
        'depression': '311',
        'pneumonia': '486',
        'copd': '496',
        'renal_failure': '5849',
    }

    # Code used when no diagnosis can be mapped (hypertension: very common).
    _DEFAULT_ICD_CODE = '4019'

    def __init__(self, config: FixedPredictionConfig):
        """
        Initialize fixed realistic drug predictor.

        No data is loaded here; call ``train()`` before predicting.

        Args:
            config: Prediction configuration
        """
        self.config = config

        # Components are created lazily by train()
        self.icd_mapper: Optional[ICDToDrugMapper] = None
        self.rules_engine: Optional[MedicalRulesEngine] = None
        self.mimic_loader: Optional[MimicDataLoader] = None

        # Prediction state
        self.is_trained: bool = False
        self.drug_vocabulary: List[str] = []

        logger.info("Initialized FixedRealisticDrugPredictor")

    def train(self) -> None:
        """
        Train the predictor by loading ICD mapper and medical rules.

        Populates ``icd_mapper``, ``drug_vocabulary`` (taken from the mapper),
        ``rules_engine`` and ``mimic_loader``. ``is_trained`` is only set once
        all of them have loaded, so a failure part-way leaves the predictor
        untrained.
        """
        logger.info("🎯 Training fixed realistic drug predictor...")

        # Initialize ICD to drug mapper
        logger.info("Loading ICD to drug mapper...")
        self.icd_mapper = create_icd_to_drug_mapper()

        # The mapper defines which drugs exist; output arrays follow this order
        self.drug_vocabulary = self.icd_mapper.get_all_drugs()

        # Initialize medical rules engine
        logger.info("Loading medical rules engine...")
        self.rules_engine = create_medical_rules_engine(self.config.custom_rules_file)

        # Initialize MIMIC loader for real patient data
        logger.info("Loading MIMIC data loader...")
        self.mimic_loader = load_mimic_data(data_root=self.config.mimic_data_root)

        self.is_trained = True
        logger.info("✅ Training complete: %d drugs in vocabulary", len(self.drug_vocabulary))

    def predict_probabilities(self,
                              clinical_context: ClinicalContext,
                              current_medications: Optional[List[str]] = None) -> np.ndarray:
        """
        Predict drug probabilities for given clinical context.

        Args:
            clinical_context: Patient clinical context
            current_medications: Current patient medications (for interaction
                checking). When omitted (None), the medications recorded on
                ``clinical_context`` are used; pass an explicit empty list to
                disable interaction checking.

        Returns:
            Probability distribution over drug vocabulary (same order as
            ``get_drug_vocabulary()``), summing to 1.

        Raises:
            ValueError: If the predictor is not trained or the vocabulary is empty.
        """
        if not self.is_trained:
            raise ValueError("Predictor not trained. Call train() first.")

        if not self.drug_vocabulary:
            raise ValueError("Drug vocabulary is empty")

        # Fall back to the medications carried by the context so interaction
        # rules are not silently skipped when the argument is left out.
        if current_medications is None:
            current_medications = getattr(clinical_context, "current_medications", None)

        # Extract ICD codes from clinical context
        icd_codes = self._extract_icd_codes(clinical_context)

        # Generate base probabilities from ICD mappings
        base_probs = self._generate_base_probabilities_from_icd(
            icd_codes, clinical_context.age, clinical_context.gender
        )

        # Apply medical safety rules
        safe_probs = self._apply_safety_rules(
            probabilities=base_probs,
            clinical_context=clinical_context,
            current_medications=current_medications or []
        )

        # Ensure minimum probabilities
        safe_probs = np.maximum(safe_probs, self.config.min_probability)

        # Final normalization. With min_probability == 0 the total can be 0
        # (and upstream rules could yield NaN); dividing would produce NaNs,
        # so fall back to a uniform distribution instead.
        total = np.sum(safe_probs)
        if not np.isfinite(total) or total <= 0:
            logger.warning("Probability mass is not positive and finite; using a uniform distribution")
            vocab_size = len(self.drug_vocabulary)
            safe_probs = np.full(vocab_size, 1.0 / vocab_size)
        else:
            safe_probs = safe_probs / total

        logger.debug("Generated probabilities for context with %d ICD codes", len(icd_codes))

        return safe_probs

    def _extract_icd_codes(self, clinical_context: ClinicalContext) -> List[str]:
        """
        Extract ICD codes from clinical context.

        For real MIMIC integration, this would use actual ICD labels.
        For testing, each diagnosis is resolved in this order:

        1. exact match against the built-in diagnosis-name table;
        2. the diagnosis is itself a code known to the ICD mapper;
        3. a built-in name occurs inside the diagnosis text, compared
           case-insensitively with spaces/hyphens treated as underscores
           (first table entry that matches wins).

        Non-string entries that match neither 1 nor 2 are skipped.

        Returns:
            One code per resolvable diagnosis, in input order; the default
            (hypertension) code alone if nothing could be resolved.
        """
        known = self._DIAGNOSIS_TO_ICD
        icd_codes: List[str] = []

        for diagnosis in clinical_context.diagnoses or []:
            # Try direct mapping
            if diagnosis in known:
                icd_codes.append(known[diagnosis])
                continue

            # Try as direct ICD code
            if diagnosis in self.icd_mapper.icd_drug_mapping:
                icd_codes.append(diagnosis)
                continue

            # Partial matching only makes sense for text
            if not isinstance(diagnosis, str):
                continue

            # "Heart failure" / "heart-failure" -> "heart_failure" so that
            # naturally written names match the underscore-style keys.
            normalized = "_".join(diagnosis.lower().replace("-", " ").split())
            for name, icd in known.items():
                if name in normalized:
                    icd_codes.append(icd)
                    break

        # Default to general condition if no mappings found
        if not icd_codes:
            if clinical_context.diagnoses:
                # Only the count is logged to keep patient data out of logs
                logger.warning(
                    "None of the %d diagnoses could be mapped to an ICD code; defaulting to %s",
                    len(clinical_context.diagnoses), self._DEFAULT_ICD_CODE,
                )
            icd_codes = [self._DEFAULT_ICD_CODE]

        return icd_codes

    def _generate_base_probabilities_from_icd(self,
                                              icd_codes: List[str],
                                              age: Optional[int],
                                              gender: Optional[str]) -> np.ndarray:
        """
        Generate base probabilities from ICD codes using mapper.

        Age / gender are only forwarded to the mapper when the corresponding
        ``use_*_adjustments`` flag is enabled; otherwise None is passed.

        Returns:
            Float array aligned with ``drug_vocabulary``; drugs the mapper
            returns no value for get ``config.min_probability``.
        """
        age_param = age if self.config.use_age_adjustments else None
        gender_param = gender if self.config.use_gender_adjustments else None

        drug_probs_dict = self.icd_mapper.generate_drug_probabilities(
            icd_codes, age=age_param, gender=gender_param
        )

        # Convert to numpy array aligned with drug vocabulary
        floor = self.config.min_probability
        return np.array(
            [drug_probs_dict.get(drug, floor) for drug in self.drug_vocabulary],
            dtype=float,
        )

    def _apply_safety_rules(self,
                            probabilities: np.ndarray,
                            clinical_context: ClinicalContext,
                            current_medications: List[str]) -> np.ndarray:
        """
        Apply medical safety rules to probabilities.

        Steps, each gated by configuration: contraindication rules,
        interaction rules (only with a non-empty medication list), then a
        blend with per-drug safety scores weighted by ``safety_weight``.
        The input array is not modified; the result is not normalized.
        Without a rules engine the input is returned unchanged.
        """
        if not self.rules_engine:
            return probabilities

        modified_probs = probabilities.copy()

        # Apply contraindication rules
        if self.config.apply_contraindications:
            modified_probs = self.rules_engine.apply_contraindication_rules(
                probabilities=modified_probs,
                drug_vocabulary=self.drug_vocabulary,
                clinical_context=clinical_context
            )

        # Apply interaction rules
        if self.config.apply_interactions and current_medications:
            modified_probs = self.rules_engine.apply_interaction_rules(
                probabilities=modified_probs,
                drug_vocabulary=self.drug_vocabulary,
                current_medications=current_medications
            )

        # Blend with safety considerations
        if self.config.safety_weight > 0:
            safety_scores = np.array([
                self.rules_engine.get_safety_score(drug, clinical_context)
                for drug in self.drug_vocabulary
            ])

            # Weighted combination: each probability is scaled by
            # (1 - w) + w * safety_score
            modified_probs = (
                (1 - self.config.safety_weight) * modified_probs +
                self.config.safety_weight * safety_scores * modified_probs
            )

        return modified_probs

    def predict_top_k(self,
                      clinical_context: ClinicalContext,
                      k: int = 5,
                      current_medications: Optional[List[str]] = None) -> List[Tuple[str, float]]:
        """
        Predict top-k drug recommendations.

        Args:
            clinical_context: Patient clinical context
            k: Number of top recommendations to return. Values larger than
                the vocabulary return every drug; values <= 0 return an
                empty list.
            current_medications: Current medications for interaction checking

        Returns:
            List of (drug_name, probability) tuples, sorted by probability
            (highest first)

        Raises:
            ValueError: If the predictor is not trained or the vocabulary is empty.
        """
        # Computed first so an untrained predictor raises regardless of k
        probabilities = self.predict_probabilities(clinical_context, current_medications)

        # A slice like [-k:] would return the whole array for k == 0,
        # so non-positive k is handled explicitly.
        if k <= 0:
            return []

        # Indices ordered from most to least probable, truncated to k
        top_indices = np.argsort(probabilities)[::-1][:k]

        return [(self.drug_vocabulary[idx], probabilities[idx]) for idx in top_indices]

    def get_drug_vocabulary(self) -> List[str]:
        """Get a copy of the drug vocabulary (callers may modify it freely)."""
        return list(self.drug_vocabulary)

    def get_icd_mapping_stats(self) -> Dict[str, Any]:
        """Get statistics about the ICD mappings (empty dict before training)."""
        if not self.icd_mapper:
            return {}
        return self.icd_mapper.get_mapping_statistics()


def create_fixed_trained_predictor(config: Optional[FixedPredictionConfig] = None) -> FixedRealisticDrugPredictor:
    """
    Build a ``FixedRealisticDrugPredictor`` and train it before returning.

    Args:
        config: Prediction configuration; a default ``FixedPredictionConfig``
            is created when omitted.

    Returns:
        The predictor, with ``train()`` already called on it.
    """
    effective_config = FixedPredictionConfig() if config is None else config

    trained_predictor = FixedRealisticDrugPredictor(effective_config)
    trained_predictor.train()
    return trained_predictor


# For testing
if __name__ == "__main__":
    # Manual smoke test: trains a predictor with the default configuration and
    # prints probability statistics and the top-5 drugs for three sample patients.
    # NOTE: this module uses relative imports, so it has to be run as a module
    # of its package (``python -m <package>.<module>``), not as a file path.
    print("🧪 Testing Fixed Realistic Drug Predictor")
    print("=" * 50)

    # Create and train predictor
    predictor = create_fixed_trained_predictor()

    # Test with different clinical contexts
    test_contexts = [
        ClinicalContext(
            patient_id="test_001",
            age=65,
            gender="M",
            diagnoses=["hypertension", "diabetes"],
            current_medications=[]
        ),
        ClinicalContext(
            patient_id="test_002",
            age=45,
            gender="F",
            diagnoses=["atrial_fibrillation"],
            current_medications=["metoprolol"]
        ),
        ClinicalContext(
            patient_id="test_003",
            age=75,
            gender="M",
            diagnoses=["heart_failure"],
            current_medications=[]
        )
    ]

    for patient_number, context in enumerate(test_contexts, start=1):
        # Printed labels are Chinese: "patient N: <diagnoses>, age, gender"
        print(f"\n患者 {patient_number}: {context.diagnoses}, 年龄{context.age}, {context.gender}")

        # Forward the patient's medications explicitly; otherwise the
        # interaction rules are never exercised (test_002 is on metoprolol).
        probs = predictor.predict_probabilities(
            context, current_medications=context.current_medications
        )
        # "probability distribution statistics": max / min / standard deviation
        print("概率分布统计:")
        print(f"  最大: {np.max(probs):.6f}")
        print(f"  最小: {np.min(probs):.6f}")
        print(f"  标准差: {np.std(probs):.6f}")

        # Get top recommendations ("top 5 recommendations")
        top_5 = predictor.predict_top_k(
            context, k=5, current_medications=context.current_medications
        )
        print("前5个推荐:")
        for drug, prob in top_5:
            print(f"  {drug}: {prob:.6f}")

    print("\n" + "="*50)
    # "... test complete!"
    print("✅ Fixed Realistic Drug Predictor 测试完成！")


