"""
Realistic Drug Probability Predictor

This module integrates MIMIC-IV statistical patterns with medical rules
to generate realistic P(drug|clinical_context) distributions, replacing
random probability generation with evidence-based predictions.

Authors: Research Team
Date: 2024
"""

import logging
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import pandas as pd
import pickle

from ..core.data_structures import ClinicalContext
from .prescription_analytics import PrescriptionAnalytics, create_prescription_analytics
from .medical_rules import MedicalRulesEngine, create_medical_rules_engine

logger = logging.getLogger(__name__)


@dataclass
class PredictionConfig:
    """Configuration for realistic drug prediction.

    ``patterns_cache_file`` may be given as a ``str`` or a ``pathlib.Path``;
    strings are converted to ``Path`` in ``__post_init__`` so code that calls
    ``Path`` methods on it (e.g. ``.exists()``) works either way.

    Raises:
        ValueError: if ``safety_weight`` is outside ``[0, 1]``.
    """

    # Data sources
    mimic_data_root: str = "data/processed/mimic"
    custom_rules_file: Optional[Path] = None

    # Statistical analysis parameters
    min_pattern_support: int = 10
    smoothing_alpha: float = 0.1

    # Prediction method
    method: str = 'hybrid'  # 'pattern', 'conditional', 'hybrid'

    # Safety constraints
    apply_contraindications: bool = True
    apply_interactions: bool = True
    # Blend weight for the safety-score adjustment; 0 disables it. Values
    # above 1 would make the blend factor (1 - w + w * score) negative.
    safety_weight: float = 0.3

    # Output configuration
    min_probability: float = 1e-6  # Minimum non-zero probability
    cache_patterns: bool = True
    patterns_cache_file: Optional[Path] = None

    def __post_init__(self) -> None:
        # Accept plain strings for the cache path: callers routinely pass
        # str paths, and Path methods are used on this field downstream.
        if self.patterns_cache_file is not None and not isinstance(self.patterns_cache_file, Path):
            self.patterns_cache_file = Path(self.patterns_cache_file)

        if not 0.0 <= self.safety_weight <= 1.0:
            raise ValueError(
                f"safety_weight must be in [0, 1], got {self.safety_weight}"
            )
    

class RealisticDrugPredictor:
    """
    Realistic drug probability predictor.

    Combines MIMIC-IV statistical patterns (from ``PrescriptionAnalytics``)
    with medical safety rules (from ``MedicalRulesEngine``) to generate
    P(drug|clinical_context) distributions over a drug vocabulary.

    Usage: construct with a ``PredictionConfig``, call ``train()`` once, then
    call ``predict_probabilities`` / ``predict_top_k``.
    """

    def __init__(self, config: PredictionConfig):
        """
        Initialize realistic drug predictor.

        No data is loaded here; ``train()`` must be called before predicting.

        Args:
            config: Prediction configuration
        """
        self.config = config

        # Components, populated by train()
        self.analytics: Optional[PrescriptionAnalytics] = None
        self.rules_engine: Optional[MedicalRulesEngine] = None

        # Prediction state
        self.is_trained: bool = False
        self.drug_vocabulary: List[str] = []

        logger.info("Initialized RealisticDrugPredictor with method=%s", config.method)

    def train(self) -> None:
        """
        Train the predictor by analyzing MIMIC-IV patterns and loading medical rules.

        If caching is enabled and the cache file exists, patterns are loaded
        from it; otherwise they are computed from scratch and, if caching is
        enabled, written to the cache file. This must be called before making
        predictions.
        """
        logger.info("🎯 Training realistic drug predictor...")

        cache_file = self.config.patterns_cache_file
        use_cache = bool(self.config.cache_patterns and cache_file)

        # Path(...) so a str cache path (not just a Path) supports .exists().
        if use_cache and Path(cache_file).exists():
            logger.info("Loading cached patterns from %s", cache_file)
            self._load_cached_patterns()
        else:
            # Train from scratch
            self._train_analytics()

            # Cache patterns if requested
            if use_cache:
                self._save_patterns()

        # Initialize medical rules engine
        self._init_rules_engine()

        # Copy the vocabulary so later mutation of the analytics object's list
        # cannot change this predictor's index -> drug mapping.
        if self.analytics:
            self.drug_vocabulary = list(self.analytics.drug_vocabulary)

        self.is_trained = True
        logger.info("✅ Training complete: %d drugs in vocabulary", len(self.drug_vocabulary))

    def predict_probabilities(self,
                            clinical_context: ClinicalContext,
                            current_medications: Optional[List[str]] = None) -> np.ndarray:
        """
        Predict drug probabilities for given clinical context.

        Args:
            clinical_context: Patient clinical context
            current_medications: Current patient medications (for interaction checking)

        Returns:
            Probability distribution over drug vocabulary (sums to 1)

        Raises:
            ValueError: if the predictor is not trained, the vocabulary is
                empty, or the resulting scores cannot be normalized (their
                sum is zero or not finite).
        """
        if not self.is_trained:
            raise ValueError("Predictor not trained. Call train() first.")

        if not self.drug_vocabulary:
            raise ValueError("Drug vocabulary is empty")

        # Generate base probabilities from statistical patterns
        base_probs = self._generate_base_probabilities(clinical_context)

        # Apply medical safety rules
        safe_probs = self._apply_safety_rules(
            probabilities=base_probs,
            clinical_context=clinical_context,
            current_medications=current_medications or []
        )

        # Ensure minimum probabilities
        safe_probs = np.maximum(np.asarray(safe_probs, dtype=float), self.config.min_probability)

        # Final normalization. Guard the divisor: with min_probability == 0 the
        # sum can be 0, and NaNs propagate through np.maximum; either case
        # would otherwise silently yield a NaN "distribution".
        total = float(np.sum(safe_probs))
        if not np.isfinite(total) or total <= 0.0:
            raise ValueError(
                f"Cannot normalize drug probabilities (sum={total}); "
                "check analytics output and min_probability"
            )
        safe_probs = safe_probs / total

        logger.debug(
            "Generated probabilities for context with %d diagnoses",
            len(clinical_context.diagnoses or []),
        )

        return safe_probs

    def predict_top_k(self,
                     clinical_context: ClinicalContext,
                     k: int = 5,
                     current_medications: Optional[List[str]] = None) -> List[Tuple[str, float]]:
        """
        Predict top-k most likely drugs with probabilities.

        Args:
            clinical_context: Patient clinical context
            k: Number of top drugs to return (k <= 0 returns an empty list)
            current_medications: Current patient medications

        Returns:
            List of (drug_name, probability) tuples, sorted by probability (descending)
        """
        probabilities = self.predict_probabilities(clinical_context, current_medications)
        return self._top_k_from_probabilities(probabilities, k)

    def _top_k_from_probabilities(self,
                                  probabilities: np.ndarray,
                                  k: int) -> List[Tuple[str, float]]:
        """Return the k highest-probability (drug, probability) pairs, descending."""
        # Without this guard, argsort(...)[-0:] would return the whole
        # vocabulary for k == 0, and a negative k would slice arbitrarily.
        if k <= 0:
            return []

        # argsort is ascending: take the last k indices, then reverse them.
        top_indices = np.argsort(probabilities)[-k:][::-1]

        return [
            (self.drug_vocabulary[idx], float(probabilities[idx]))
            for idx in top_indices
        ]

    def _train_analytics(self) -> None:
        """Train prescription analytics from MIMIC-IV data."""
        logger.info("Training prescription analytics...")

        # Create analytics instance
        self.analytics = create_prescription_analytics(
            data_root=self.config.mimic_data_root,
            min_pattern_support=self.config.min_pattern_support,
            smoothing_alpha=self.config.smoothing_alpha
        )

        # Analyze patterns
        self.analytics.analyze_prescription_patterns()

        logger.info("Analytics training complete")

    def _init_rules_engine(self) -> None:
        """Initialize medical rules engine (optionally from a custom rules file)."""
        logger.info("Initializing medical rules engine...")

        self.rules_engine = create_medical_rules_engine(
            rules_file=self.config.custom_rules_file
        )

        logger.info("Rules engine initialized")

    def _generate_base_probabilities(self, clinical_context: ClinicalContext) -> np.ndarray:
        """Generate base probabilities from statistical patterns.

        Raises:
            ValueError: if analytics have not been initialized.
        """
        if not self.analytics:
            raise ValueError("Analytics not initialized")

        return self.analytics.generate_probability_distribution(
            clinical_context=clinical_context,
            method=self.config.method
        )

    def _apply_safety_rules(self,
                          probabilities: np.ndarray,
                          clinical_context: ClinicalContext,
                          current_medications: List[str]) -> np.ndarray:
        """Apply medical safety rules to probabilities.

        Returns the input unchanged when no rules engine is loaded; otherwise
        returns a modified copy (the input array is not mutated here).
        """
        if not self.rules_engine:
            return probabilities

        modified_probs = probabilities.copy()

        # Apply contraindication rules
        if self.config.apply_contraindications:
            modified_probs = self.rules_engine.apply_contraindication_rules(
                probabilities=modified_probs,
                drug_vocabulary=self.drug_vocabulary,
                clinical_context=clinical_context
            )

        # Apply interaction rules (only meaningful with current medications)
        if self.config.apply_interactions and current_medications:
            modified_probs = self.rules_engine.apply_interaction_rules(
                probabilities=modified_probs,
                drug_vocabulary=self.drug_vocabulary,
                current_medications=current_medications
            )

        # Blend with safety considerations
        if self.config.safety_weight > 0:
            safety_scores = np.array([
                self.rules_engine.get_safety_score(drug, clinical_context)
                for drug in self.drug_vocabulary
            ], dtype=float)

            # Weighted combination; equivalent to scaling each probability by
            # (1 - w + w * safety_score).
            modified_probs = (
                (1 - self.config.safety_weight) * modified_probs +
                self.config.safety_weight * safety_scores * modified_probs
            )

        return modified_probs

    def _save_patterns(self) -> None:
        """Save trained patterns to cache (no-op without analytics or cache path)."""
        if not self.analytics or not self.config.patterns_cache_file:
            return

        self.analytics.save_patterns(self.config.patterns_cache_file)
        logger.info("Cached patterns to %s", self.config.patterns_cache_file)

    def _load_cached_patterns(self) -> None:
        """Load patterns from cache (no-op without a cache path)."""
        if not self.config.patterns_cache_file:
            return

        # Create analytics instance
        self.analytics = create_prescription_analytics(
            data_root=self.config.mimic_data_root,
            min_pattern_support=self.config.min_pattern_support,
            smoothing_alpha=self.config.smoothing_alpha
        )

        # Load cached patterns.
        # NOTE(review): this module imports pickle; if load_patterns deserializes
        # with pickle, only point patterns_cache_file at trusted files — confirm.
        self.analytics.load_patterns(self.config.patterns_cache_file)
        logger.info("Loaded cached patterns from %s", self.config.patterns_cache_file)

    def get_prediction_summary(self, clinical_context: ClinicalContext) -> Dict:
        """
        Get detailed prediction summary for analysis.

        Args:
            clinical_context: Patient clinical context

        Returns:
            Dictionary with prediction details and explanations

        Raises:
            ValueError: if the predictor is not trained.
        """
        if not self.is_trained:
            raise ValueError("Predictor not trained")

        # Run the prediction pipeline once and derive the top-10 from the same
        # distribution (previously predict_top_k re-ran the whole pipeline).
        probabilities = self.predict_probabilities(clinical_context)
        top_drugs = self._top_k_from_probabilities(probabilities, 10)

        # Generate summary
        summary = {
            'clinical_context': {
                'diagnoses': clinical_context.diagnoses or [],
                'age': clinical_context.age,
                'gender': clinical_context.gender
            },
            'prediction_method': self.config.method,
            'total_drugs': len(self.drug_vocabulary),
            'top_predictions': top_drugs,
            'probability_statistics': {
                'max': float(np.max(probabilities)),
                'min': float(np.min(probabilities)),
                'mean': float(np.mean(probabilities)),
                'std': float(np.std(probabilities)),
                # Shannon entropy in nats; epsilon guards log(0).
                'entropy': float(-np.sum(probabilities * np.log(probabilities + 1e-10)))
            }
        }

        # Add safety validation if available
        if self.rules_engine:
            top_drug_names = [drug for drug, _ in top_drugs]
            validation = self.rules_engine.validate_prescription_set(
                drugs=top_drug_names,
                clinical_context=clinical_context
            )
            summary['safety_validation'] = validation

        return summary

    def get_drug_vocabulary(self) -> List[str]:
        """Get a copy of the drug vocabulary."""
        return self.drug_vocabulary.copy()

    def is_drug_contraindicated(self, drug_name: str, clinical_context: ClinicalContext) -> bool:
        """
        Check if a drug is contraindicated for the patient.

        Args:
            drug_name: Drug name
            clinical_context: Patient clinical context

        Returns:
            True if the rules engine's safety score is below 0.5; False
            otherwise, or when no rules engine is loaded.
        """
        if not self.rules_engine:
            return False

        safety_score = self.rules_engine.get_safety_score(drug_name, clinical_context)
        return safety_score < 0.5  # Arbitrary threshold for "contraindicated"


def create_realistic_predictor(config: Optional[PredictionConfig] = None) -> RealisticDrugPredictor:
    """
    Create an untrained realistic drug predictor.

    The returned predictor has not loaded any patterns or rules yet. Call its
    ``train()`` method before predicting, or use ``create_trained_predictor``,
    which performs both steps.

    Args:
        config: Optional configuration (a default ``PredictionConfig`` is used if None)

    Returns:
        Untrained RealisticDrugPredictor instance
    """
    if config is None:
        config = PredictionConfig()

    return RealisticDrugPredictor(config)


def create_trained_predictor(config: Optional[PredictionConfig] = None,
                           cache_file: Optional[str] = None) -> RealisticDrugPredictor:
    """
    Create and automatically train realistic drug predictor.

    Args:
        config: Optional configuration (defaults to ``PredictionConfig()``).
            The caller's object is never modified.
        cache_file: Optional pattern-cache path. When given, the predictor
            uses a copy of ``config`` with ``patterns_cache_file`` set to this
            path and ``cache_patterns`` enabled.

    Returns:
        Trained RealisticDrugPredictor instance
    """
    from dataclasses import replace

    if config is None:
        config = PredictionConfig()

    # Set cache file if provided. Work on a copy: mutating the caller's
    # config would leak the cache settings into any other predictor that
    # shares the same config object.
    if cache_file:
        config = replace(
            config,
            patterns_cache_file=Path(cache_file),
            cache_patterns=True,
        )

    predictor = create_realistic_predictor(config)
    predictor.train()

    return predictor

