"""
Medical Rules Engine for Drug Prescription

This module implements medical rules and contraindications to constrain
drug probability distributions based on clinical guidelines.

Authors: Research Team  
Date: 2024
"""

import logging
from typing import Dict, List, Set, Optional, Tuple
from dataclasses import dataclass
from pathlib import Path
import json
import numpy as np

from ..core.data_structures import ClinicalContext

logger = logging.getLogger(__name__)


@dataclass
class DrugContraindication:
    """Drug contraindication rule.

    A rule describes one drug (or drug class, matched by name) together with
    the patient criteria under which it should be avoided.
    """

    drug_name: str
    # Diagnoses that trigger the rule.
    contraindicated_conditions: List[str]
    # (min_age, max_age) pairs; None leaves that side of the range open.
    contraindicated_age_ranges: List[Tuple[Optional[float], Optional[float]]]
    contraindicated_genders: List[str]
    severity: str  # 'absolute', 'relative', 'caution'
    reason: str

    def __post_init__(self) -> None:
        """Normalize and validate age ranges at construction time.

        Rules loaded from JSON arrive with ranges as lists (e.g. ``[[null, 18]]``).
        They are converted to tuples to match the annotation. A range that is
        not a (min_age, max_age) pair, or a bound that is neither None nor a
        number, is rejected here. Otherwise it would only fail later, while
        probabilities are being adjusted. ``None`` is treated as "no age
        ranges".

        Raises:
            ValueError: If a range does not have exactly two elements.
            TypeError: If a range is not iterable or a bound is not numeric/None.
        """
        normalized: List[Tuple[Optional[float], Optional[float]]] = []
        for age_range in self.contraindicated_age_ranges or []:
            bounds = tuple(age_range)
            if len(bounds) != 2:
                raise ValueError(
                    f"{self.drug_name}: age range must be (min_age, max_age), "
                    f"got {age_range!r}"
                )
            for bound in bounds:
                if bound is not None and not isinstance(bound, (int, float)):
                    raise TypeError(
                        f"{self.drug_name}: age bound must be a number or None, "
                        f"got {bound!r}"
                    )
            normalized.append(bounds)
        self.contraindicated_age_ranges = normalized
    

@dataclass
class DrugInteractionRule:
    """Drug-drug interaction rule.

    Pairs two drug names with the clinical effect of combining them and a
    multiplicative penalty for recommending one while the other is taken.
    """

    drug1: str
    drug2: str
    interaction_type: str  # 'major', 'moderate', 'minor'
    effect: str
    # Factor by which an interacting candidate's probability is multiplied;
    # values below 1 reduce it (0.3 keeps 30%).
    risk_multiplier: float

    def __post_init__(self) -> None:
        """Coerce ``risk_multiplier`` to float and validate it.

        A negative multiplier would produce negative "probabilities". A NaN
        or infinite one poisons the sum used for renormalization. Both are
        rejected here rather than silently corrupting a distribution later.

        Raises:
            ValueError: If the multiplier is negative, NaN, infinite, or a
                non-numeric string.
            TypeError: If the multiplier cannot be converted to float at all.
        """
        multiplier = float(self.risk_multiplier)
        if not np.isfinite(multiplier) or multiplier < 0:
            raise ValueError(
                f"{self.drug1} + {self.drug2}: risk_multiplier must be a finite, "
                f"non-negative number, got {self.risk_multiplier!r}"
            )
        self.risk_multiplier = multiplier


class MedicalRulesEngine:
    """
    Medical rules engine for constraining drug probabilities.

    Applies clinical guidelines and contraindications to ensure
    generated probabilities respect medical safety constraints.

    Drug names are compared case-insensitively. A rule name and a candidate
    name match when, after stripping surrounding whitespace, either one is a
    substring of the other. Blank names never match, so an empty entry
    cannot act as a wildcard. Severity strings are also compared
    case-insensitively.
    """

    # Multiplicative penalty applied to a drug's probability, by severity.
    _PROBABILITY_PENALTIES: Dict[str, float] = {
        "absolute": 0.01,  # nearly eliminate
        "relative": 0.3,   # strongly reduce
        "caution": 0.7,    # moderately reduce
    }

    # Multiplicative penalty applied to a drug's safety score, by severity.
    _SAFETY_PENALTIES: Dict[str, float] = {
        "absolute": 0.1,
        "relative": 0.5,
        "caution": 0.8,
    }

    def __init__(self, rules_file: Optional[Path] = None):
        """
        Initialize medical rules engine.

        Args:
            rules_file: Optional JSON file with custom rules (a ``Path`` or a
                ``str`` path). A missing file is logged and ignored.
        """
        self.contraindications: List[DrugContraindication] = []
        self.interaction_rules: List[DrugInteractionRule] = []
        # Not read by any method in this class; kept as a public attribute.
        self.drug_safety_profiles: Dict[str, Dict[str, float]] = {}

        # Load default rules
        self._load_default_rules()

        # Load custom rules if provided (falsy values such as "" are skipped).
        if rules_file:
            rules_path = Path(rules_file)
            if rules_path.is_file():
                self._load_custom_rules(rules_path)
            else:
                logger.warning("Custom rules file not found: %s", rules_path)

        logger.info("Initialized medical rules: %d contraindications, %d interactions",
                    len(self.contraindications), len(self.interaction_rules))

    def _load_default_rules(self) -> None:
        """Load the built-in contraindication and interaction rules."""

        # Common pediatric contraindications
        self.contraindications.extend([
            DrugContraindication(
                drug_name="Aspirin",
                contraindicated_conditions=["Reye syndrome risk"],
                contraindicated_age_ranges=[(None, 18)],
                contraindicated_genders=[],
                severity="absolute",
                reason="Risk of Reye syndrome in children"
            ),
            DrugContraindication(
                drug_name="Tetracycline",
                contraindicated_conditions=[],
                contraindicated_age_ranges=[(None, 8)],
                contraindicated_genders=[],
                severity="absolute",
                reason="Tooth discoloration in children"
            ),
        ])

        # Pregnancy contraindications
        self.contraindications.extend([
            DrugContraindication(
                drug_name="Warfarin",
                contraindicated_conditions=["pregnancy"],
                contraindicated_age_ranges=[],
                contraindicated_genders=[],
                severity="absolute",
                reason="Teratogenic effects"
            ),
            DrugContraindication(
                drug_name="ACE inhibitors",
                contraindicated_conditions=["pregnancy"],
                contraindicated_age_ranges=[],
                contraindicated_genders=[],
                severity="absolute",
                reason="Fetal kidney development"
            ),
        ])

        # Elderly contraindications
        self.contraindications.extend([
            DrugContraindication(
                drug_name="Benzodiazepines",
                contraindicated_conditions=[],
                contraindicated_age_ranges=[(75, None)],
                contraindicated_genders=[],
                severity="relative",
                reason="Fall risk in elderly"
            ),
        ])

        # Condition-specific contraindications
        self.contraindications.extend([
            DrugContraindication(
                drug_name="Beta blockers",
                contraindicated_conditions=["asthma", "COPD"],
                contraindicated_age_ranges=[],
                contraindicated_genders=[],
                severity="relative",
                reason="Bronchospasm risk"
            ),
            DrugContraindication(
                drug_name="NSAIDs",
                contraindicated_conditions=["chronic kidney disease", "heart failure"],
                contraindicated_age_ranges=[],
                contraindicated_genders=[],
                severity="relative",
                reason="Kidney function and fluid retention"
            ),
        ])

        # Drug interaction rules
        self.interaction_rules.extend([
            DrugInteractionRule(
                drug1="Warfarin",
                drug2="Aspirin",
                interaction_type="major",
                effect="Increased bleeding risk",
                risk_multiplier=0.3
            ),
            DrugInteractionRule(
                drug1="Digoxin",
                drug2="Amiodarone",
                interaction_type="major",
                effect="Digoxin toxicity",
                risk_multiplier=0.2
            ),
            DrugInteractionRule(
                drug1="Simvastatin",
                drug2="Clarithromycin",
                interaction_type="major",
                effect="Rhabdomyolysis risk",
                risk_multiplier=0.1
            ),
        ])

        logger.info("Loaded default medical rules")

    def _load_custom_rules(self, rules_file: Path) -> None:
        """
        Load custom rules from a UTF-8 JSON file.

        The file must hold a JSON object. Its optional ``contraindications``
        and ``interactions`` keys are lists of objects whose keys are the
        field names of DrugContraindication and DrugInteractionRule.

        Loading is all-or-nothing. Every rule is parsed first, and the rules
        are added only if the whole file parsed cleanly. A malformed entry
        therefore cannot leave the engine with a partial rule set. Failures
        are logged as warnings rather than raised, so a bad custom file never
        prevents the engine (with its default rules) from starting.
        """
        try:
            with open(rules_file, 'r', encoding='utf-8') as f:
                custom_rules = json.load(f)

            if not isinstance(custom_rules, dict):
                raise ValueError("top-level JSON value must be an object")

            # `or []` treats an explicit null the same as a missing key.
            new_contraindications = [
                DrugContraindication(**rule_data)
                for rule_data in custom_rules.get('contraindications') or []
            ]
            new_interactions = [
                DrugInteractionRule(**rule_data)
                for rule_data in custom_rules.get('interactions') or []
            ]
        except (OSError, ValueError, TypeError) as e:
            # OSError: unreadable file; ValueError: invalid JSON/encoding or
            # rejected rule values; TypeError: missing/unknown rule fields.
            logger.warning(f"Failed to load custom rules from {rules_file}: {e}")
            return

        self.contraindications.extend(new_contraindications)
        self.interaction_rules.extend(new_interactions)
        logger.info(f"Loaded custom rules from {rules_file}")

    def apply_contraindication_rules(self,
                                   probabilities: np.ndarray,
                                   drug_vocabulary: List[str],
                                   clinical_context: ClinicalContext) -> np.ndarray:
        """
        Apply contraindication rules to modify drug probabilities.

        Each rule that applies to the patient multiplies the probability of
        every matching vocabulary entry by a severity-dependent penalty.
        Penalties from several rules compound. Rules with an unrecognised
        severity leave probabilities untouched.

        Args:
            probabilities: Original drug probabilities, aligned with
                ``drug_vocabulary``. Not modified.
            drug_vocabulary: List of drug names
            clinical_context: Patient clinical context

        Returns:
            A new array with contraindications applied, renormalized to sum
            to 1 when the total remains positive.
        """
        modified_probs = self._as_float_copy(probabilities)

        for contraindication in self.contraindications:
            # The patient-level check does not depend on the drug index, so
            # evaluate it once per rule.
            if not self._is_contraindicated(contraindication, clinical_context):
                continue

            penalty = self._PROBABILITY_PENALTIES.get(str(contraindication.severity).lower())
            if penalty is None:
                continue

            for drug_idx in self._find_matching_drugs(contraindication.drug_name, drug_vocabulary):
                modified_probs[drug_idx] *= penalty
                logger.debug("Applied %s contraindication: %s for %s",
                             contraindication.severity, contraindication.drug_name,
                             clinical_context)

        return self._renormalize(modified_probs)

    def apply_interaction_rules(self,
                              probabilities: np.ndarray,
                              drug_vocabulary: List[str],
                              current_medications: List[str]) -> np.ndarray:
        """
        Apply drug interaction rules to modify probabilities.

        For every rule where the patient already takes one drug of the pair,
        the probability of each vocabulary entry matching the *other* drug is
        multiplied by the rule's ``risk_multiplier``. If the patient takes
        both drugs, both partners are penalized. Current medications are
        matched with the same case-insensitive name matching used for the
        vocabulary.

        Args:
            probabilities: Original drug probabilities, aligned with
                ``drug_vocabulary``. Not modified.
            drug_vocabulary: List of drug names
            current_medications: List of current patient medications
                (``None`` is treated as no medications).

        Returns:
            A new array with interactions considered, renormalized to sum to 1
            when the total remains positive.
        """
        modified_probs = self._as_float_copy(probabilities)
        current = [] if current_medications is None else list(current_medications)

        for interaction in self.interaction_rules:
            partners = []
            if any(self._names_match(interaction.drug1, med) for med in current):
                partners.append(interaction.drug2)
            if any(self._names_match(interaction.drug2, med) for med in current):
                partners.append(interaction.drug1)

            for partner in partners:
                for drug_idx in self._find_matching_drugs(partner, drug_vocabulary):
                    modified_probs[drug_idx] *= interaction.risk_multiplier
                    logger.debug("Applied interaction rule: %s + %s",
                                 interaction.drug1, interaction.drug2)

        return self._renormalize(modified_probs)

    @staticmethod
    def _as_float_copy(probabilities: np.ndarray) -> np.ndarray:
        """Return a floating-point copy of *probabilities* (never a view).

        Floating input keeps its dtype. Anything else (ints, lists) is
        promoted to float64, so in-place penalties are not truncated to 0.
        """
        arr = np.asarray(probabilities)
        if np.issubdtype(arr.dtype, np.floating):
            return arr.copy()
        return arr.astype(np.float64)

    @staticmethod
    def _renormalize(probs: np.ndarray) -> np.ndarray:
        """Scale *probs* to sum to 1; return it unchanged if the sum is not positive."""
        total = np.sum(probs)
        if total > 0:
            return probs / total
        return probs

    @staticmethod
    def _names_match(name_a: str, name_b: str) -> bool:
        """Return True if two drug names match: case-insensitive, either containing the other.

        Non-string or blank (empty/whitespace-only) names never match.
        """
        if not isinstance(name_a, str) or not isinstance(name_b, str):
            return False
        a = name_a.strip().lower()
        b = name_b.strip().lower()
        if not a or not b:
            return False
        return a in b or b in a

    def _find_matching_drugs(self, target_drug: str, drug_vocabulary: List[str]) -> List[int]:
        """Return indices of vocabulary entries that match *target_drug* (see ``_names_match``)."""
        return [
            idx for idx, drug in enumerate(drug_vocabulary)
            if self._names_match(target_drug, drug)
        ]

    def _is_contraindicated(self,
                          contraindication: DrugContraindication,
                          clinical_context: ClinicalContext) -> bool:
        """
        Check whether a contraindication applies to the clinical context.

        The rule applies when ANY criterion matches:
        - age: the patient's age lies in one of the (inclusive) age ranges;
        - gender: the patient's gender equals a listed gender (case-insensitive);
        - condition: a listed condition is a case-insensitive substring of
          one of the patient's diagnoses. Blank conditions are ignored.

        Context attributes are read only when the rule has a corresponding
        criterion.
        """
        # Check age contraindications
        if contraindication.contraindicated_age_ranges and clinical_context.age is not None:
            for min_age, max_age in contraindication.contraindicated_age_ranges:
                if self._age_in_range(clinical_context.age, min_age, max_age):
                    return True

        # Check gender contraindications
        if contraindication.contraindicated_genders and clinical_context.gender:
            blocked_genders = {g.lower() for g in contraindication.contraindicated_genders}
            if clinical_context.gender.lower() in blocked_genders:
                return True

        # Check condition contraindications
        if contraindication.contraindicated_conditions and clinical_context.diagnoses:
            patient_conditions = [d.lower() for d in clinical_context.diagnoses]
            for condition in contraindication.contraindicated_conditions:
                needle = condition.strip().lower()
                # A blank condition would be a substring of every diagnosis.
                if needle and any(needle in patient_condition
                                  for patient_condition in patient_conditions):
                    return True

        return False

    def _age_in_range(self, age: float, min_age: Optional[float], max_age: Optional[float]) -> bool:
        """Return True if ``min_age <= age <= max_age``; a None bound is unbounded."""
        if min_age is not None and age < min_age:
            return False
        if max_age is not None and age > max_age:
            return False
        return True

    def get_safety_score(self,
                        drug_name: str,
                        clinical_context: ClinicalContext) -> float:
        """
        Get safety score for a drug given clinical context.

        Starts at 1.0 and multiplies in a severity-dependent penalty for every
        applicable contraindication whose drug name matches ``drug_name``.
        Unrecognised severities contribute no penalty.

        Args:
            drug_name: Drug name
            clinical_context: Patient clinical context

        Returns:
            Safety score (0-1, higher is safer)
        """
        safety_score = 1.0

        for contraindication in self.contraindications:
            if not self._names_match(drug_name, contraindication.drug_name):
                continue
            if not self._is_contraindicated(contraindication, clinical_context):
                continue
            safety_score *= self._SAFETY_PENALTIES.get(str(contraindication.severity).lower(), 1.0)

        return safety_score

    def validate_prescription_set(self,
                                drugs: List[str],
                                clinical_context: ClinicalContext) -> Dict[str, List[Dict[str, str]]]:
        """
        Validate a set of prescribed drugs for contraindications and interactions.

        Args:
            drugs: Prescribed drug names (any iterable of strings).
            clinical_context: Patient clinical context

        Returns:
            Dictionary with keys 'warnings', 'contraindications' and
            'interactions', each a list of dicts:
            - contraindications: {'drug', 'reason', 'severity'}
            - interactions: {'drug1', 'drug2', 'type', 'effect'}
            'warnings' is not populated by this method and is always empty.
        """
        drugs = list(drugs)  # iterated twice and sliced below
        validation_results: Dict[str, List[Dict[str, str]]] = {
            'warnings': [],
            'contraindications': [],
            'interactions': []
        }

        # Check contraindications
        for drug in drugs:
            for contraindication in self.contraindications:
                if (self._names_match(drug, contraindication.drug_name) and
                        self._is_contraindicated(contraindication, clinical_context)):
                    validation_results['contraindications'].append({
                        'drug': drug,
                        'reason': contraindication.reason,
                        'severity': contraindication.severity
                    })

        # Check interactions between every unordered pair of prescribed drugs.
        for i, drug1 in enumerate(drugs):
            for drug2 in drugs[i + 1:]:
                for interaction in self.interaction_rules:
                    forward = (self._names_match(drug1, interaction.drug1) and
                               self._names_match(drug2, interaction.drug2))
                    backward = (self._names_match(drug1, interaction.drug2) and
                                self._names_match(drug2, interaction.drug1))
                    if forward or backward:
                        validation_results['interactions'].append({
                            'drug1': drug1,
                            'drug2': drug2,
                            'type': interaction.interaction_type,
                            'effect': interaction.effect
                        })

        return validation_results


def create_medical_rules_engine(rules_file: Optional[Path] = None) -> MedicalRulesEngine:
    """
    Build a MedicalRulesEngine.

    Thin factory for call sites that prefer a function over calling the
    constructor directly; ``rules_file`` is forwarded unchanged.

    Args:
        rules_file: Optional path to a custom rules file, passed straight
            through to ``MedicalRulesEngine``.

    Returns:
        A newly constructed MedicalRulesEngine.
    """
    engine = MedicalRulesEngine(rules_file=rules_file)
    return engine

