"""
ICD Diagnosis to Drug Recommendation Mapper

Converts ICD diagnosis codes to clinically appropriate drug recommendations
based on standard treatment guidelines. This bridges MIMIC's diagnostic data
with drug recommendation tasks.

Authors: Research Team
Date: 2024
"""

import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set

import numpy as np

logger = logging.getLogger(__name__)


@dataclass
class DrugRecommendation:
    """A single drug suggested for a diagnosis, with confidence and context.

    Instances are the values stored per ICD code in the mapper's
    ``icd_drug_mapping`` table.
    """
    # Name of the drug, used as the key in probability distributions.
    drug_name: str
    # Guideline-derived weight; intended range is 0-1 (not validated here).
    confidence: float
    # Primary indication this drug addresses for the diagnosis.
    indication: str
    # Free-form role label. Values used in this file: "first-line",
    # "second-line", "alternative", "adjunct", "variable", "traditional",
    # "hospital", "situational", "acute".
    line_of_therapy: str
    # Optional list of contraindication descriptions; None when unspecified.
    contraindications: Optional[List[str]] = None


class ICDToDrugMapper:
    """
    Maps ICD-9 diagnosis codes to appropriate drug recommendations.

    Based on standard clinical guidelines and treatment protocols.
    Provides realistic P(drug|diagnosis) distributions for CNCRC experiments.

    Attributes:
        icd_drug_mapping: ICD-9 code written without a decimal point
            (e.g. ``'4019'``) -> list of recommendations for that diagnosis.
        drug_universe: Every drug name appearing in any mapping entry.
    """

    # Baseline mass assigned to drugs that none of the supplied diagnoses
    # support, so every drug in the universe keeps a non-zero probability.
    _UNMAPPED_DRUG_PROBABILITY = 0.001

    # Age-rule tables, hoisted to class level so they are built only once.
    _PEDIATRIC_AVOID = frozenset({'aspirin', 'warfarin', 'atorvastatin'})
    _ELDERLY_CAUTION = frozenset({'digoxin', 'amiodarone', 'glyburide'})
    _ELDERLY_PREFERRED = frozenset({'lisinopril', 'metoprolol'})

    def __init__(self):
        """Initialize the ICD to drug mapper with clinical guidelines."""
        self.icd_drug_mapping: Dict[str, List[DrugRecommendation]] = {}
        self.drug_universe: Set[str] = set()

        # Load standard mappings
        self._load_clinical_guidelines()

        # Lazy %-formatting: the message is only built if INFO is enabled.
        logger.info(
            "Initialized ICD mapper with %d diagnoses, %d unique drugs",
            len(self.icd_drug_mapping),
            len(self.drug_universe),
        )

    def _load_clinical_guidelines(self) -> None:
        """Populate ``icd_drug_mapping`` and ``drug_universe``.

        The table below is hard-coded; each entry is
        ``DrugRecommendation(drug_name, confidence, indication, line_of_therapy)``.
        """

        # Essential hypertension (4019)
        self.icd_drug_mapping['4019'] = [
            DrugRecommendation('lisinopril', 0.8, 'hypertension', 'first-line'),
            DrugRecommendation('amlodipine', 0.7, 'hypertension', 'first-line'),
            DrugRecommendation('hydrochlorothiazide', 0.7, 'hypertension', 'first-line'),
            DrugRecommendation('metoprolol', 0.6, 'hypertension', 'first-line'),
            DrugRecommendation('losartan', 0.6, 'hypertension', 'first-line'),
            DrugRecommendation('atorvastatin', 0.4, 'cardiovascular risk', 'adjunct'),
            DrugRecommendation('aspirin', 0.3, 'cardiovascular prevention', 'adjunct'),
        ]

        # Diabetes mellitus (25000)
        self.icd_drug_mapping['25000'] = [
            DrugRecommendation('metformin', 0.9, 'diabetes', 'first-line'),
            DrugRecommendation('insulin', 0.7, 'diabetes', 'variable'),
            DrugRecommendation('glyburide', 0.5, 'diabetes', 'second-line'),
            DrugRecommendation('glipizide', 0.5, 'diabetes', 'second-line'),
            DrugRecommendation('sitagliptin', 0.4, 'diabetes', 'second-line'),
            DrugRecommendation('lisinopril', 0.6, 'diabetic nephropathy', 'adjunct'),
            DrugRecommendation('atorvastatin', 0.5, 'diabetic dyslipidemia', 'adjunct'),
        ]

        # Atrial fibrillation (42731)
        self.icd_drug_mapping['42731'] = [
            DrugRecommendation('warfarin', 0.6, 'anticoagulation', 'traditional'),
            DrugRecommendation('apixaban', 0.8, 'anticoagulation', 'first-line'),
            DrugRecommendation('rivaroxaban', 0.7, 'anticoagulation', 'first-line'),
            DrugRecommendation('dabigatran', 0.6, 'anticoagulation', 'first-line'),
            DrugRecommendation('metoprolol', 0.7, 'rate control', 'first-line'),
            DrugRecommendation('diltiazem', 0.6, 'rate control', 'alternative'),
            DrugRecommendation('amiodarone', 0.4, 'rhythm control', 'second-line'),
        ]

        # Congestive heart failure (4280)
        self.icd_drug_mapping['4280'] = [
            DrugRecommendation('furosemide', 0.9, 'volume overload', 'first-line'),
            DrugRecommendation('lisinopril', 0.8, 'ACE inhibition', 'first-line'),
            DrugRecommendation('metoprolol', 0.7, 'beta blockade', 'first-line'),
            DrugRecommendation('spironolactone', 0.6, 'aldosterone antagonist', 'first-line'),
            DrugRecommendation('digoxin', 0.4, 'inotropic support', 'second-line'),
            DrugRecommendation('losartan', 0.5, 'ARB (if ACE intolerant)', 'alternative'),
        ]

        # Depressive disorder (311)
        self.icd_drug_mapping['311'] = [
            DrugRecommendation('sertraline', 0.8, 'depression', 'first-line'),
            DrugRecommendation('fluoxetine', 0.7, 'depression', 'first-line'),
            DrugRecommendation('escitalopram', 0.7, 'depression', 'first-line'),
            DrugRecommendation('citalopram', 0.6, 'depression', 'first-line'),
            DrugRecommendation('bupropion', 0.5, 'depression', 'alternative'),
            DrugRecommendation('mirtazapine', 0.4, 'depression', 'second-line'),
        ]

        # Pneumonia (486)
        self.icd_drug_mapping['486'] = [
            DrugRecommendation('azithromycin', 0.7, 'atypical pneumonia', 'first-line'),
            DrugRecommendation('amoxicillin', 0.6, 'bacterial pneumonia', 'first-line'),
            DrugRecommendation('levofloxacin', 0.6, 'pneumonia', 'first-line'),
            DrugRecommendation('ceftriaxone', 0.5, 'severe pneumonia', 'hospital'),
            DrugRecommendation('doxycycline', 0.5, 'atypical pneumonia', 'alternative'),
        ]

        # Acute renal failure (5849)
        self.icd_drug_mapping['5849'] = [
            DrugRecommendation('furosemide', 0.6, 'volume management', 'situational'),
            DrugRecommendation('sodium_bicarbonate', 0.4, 'acidosis', 'situational'),
            # Note: Many drugs are avoided in renal failure
        ]

        # COPD (496)
        self.icd_drug_mapping['496'] = [
            DrugRecommendation('albuterol', 0.9, 'bronchodilation', 'first-line'),
            DrugRecommendation('ipratropium', 0.7, 'bronchodilation', 'first-line'),
            DrugRecommendation('prednisone', 0.6, 'exacerbation', 'acute'),
            DrugRecommendation('azithromycin', 0.5, 'exacerbation', 'acute'),
            DrugRecommendation('tiotropium', 0.5, 'maintenance', 'second-line'),
        ]

        # Build drug universe from every recommendation in the table.
        self.drug_universe.update(
            rec.drug_name
            for recommendations in self.icd_drug_mapping.values()
            for rec in recommendations
        )

    def get_drug_recommendations(self, icd_code: str) -> List["DrugRecommendation"]:
        """Get drug recommendations for a specific ICD code.

        Args:
            icd_code: ICD-9 code exactly as keyed in ``icd_drug_mapping``.

        Returns:
            A new list holding the recommendations for ``icd_code`` (empty if
            the code is unknown). The list is a shallow copy, so adding or
            removing items does not alter the mapper's table.
        """
        return list(self.icd_drug_mapping.get(icd_code, ()))

    def get_all_drugs(self) -> List[str]:
        """Get an alphabetically sorted list of all drugs in the universe."""
        return sorted(self.drug_universe)

    def generate_drug_probabilities(self,
                                   icd_codes: List[str],
                                   age: Optional[int] = None,
                                   gender: Optional[str] = None) -> Dict[str, float]:
        """
        Generate drug probability distribution for given ICD codes.

        Each recommendation's confidence (after age/gender adjustment) is
        summed per drug across the supplied codes; a code listed twice is
        counted twice. The scores are normalized, every drug in the universe
        without a score receives a small baseline probability, and the result
        is normalized again so it sums to 1.

        Args:
            icd_codes: List of ICD-9 diagnosis codes
            age: Patient age (for age-specific adjustments)
            gender: Patient gender (for gender-specific adjustments)

        Returns:
            Dictionary mapping drug names to probabilities. If no code matches
            or the matched scores do not sum to a positive value, a uniform
            distribution over the whole drug universe is returned (an empty
            dict when the universe is empty).
        """
        drug_scores: Dict[str, float] = defaultdict(float)

        # Aggregate recommendations from all diagnoses
        for icd_code in icd_codes:
            for rec in self.get_drug_recommendations(icd_code):
                # Base confidence from clinical guidelines
                score = rec.confidence

                if age is not None:
                    score = self._apply_age_adjustments(rec.drug_name, score, age)

                if gender is not None:
                    score = self._apply_gender_adjustments(rec.drug_name, score, gender)

                # Accumulate scores (multiple diagnoses can support same drug)
                drug_scores[rec.drug_name] += score

        all_drugs = self.get_all_drugs()
        total_score = sum(drug_scores.values())

        # Covers both "nothing matched" and "matched but all-zero scores";
        # the latter would otherwise divide by zero below.
        if total_score <= 0:
            return {drug: 1.0 / len(all_drugs) for drug in all_drugs}

        # Convert to probabilities
        probabilities = {
            drug: score / total_score for drug, score in drug_scores.items()
        }

        # Add small probabilities for unmapped drugs
        for drug in all_drugs:
            probabilities.setdefault(drug, self._UNMAPPED_DRUG_PROBABILITY)

        # Final normalization (the baseline mass pushed the sum above 1)
        total_prob = sum(probabilities.values())
        return {drug: prob / total_prob for drug, prob in probabilities.items()}

    def _apply_age_adjustments(self, drug_name: str, score: float, age: int) -> float:
        """Apply age-specific drug score adjustments.

        Args:
            drug_name: Drug whose score is being adjusted.
            score: Current score for the drug.
            age: Patient age in years.

        Returns:
            The adjusted score: x0.1 for pediatric-avoid drugs when age < 18;
            x0.7 for caution drugs and x1.2 for preferred drugs when
            age >= 65; unchanged otherwise.
        """
        # Pediatric considerations (age < 18)
        if age < 18:
            if drug_name in self._PEDIATRIC_AVOID:
                score *= 0.1

        # Elderly considerations (age >= 65)
        elif age >= 65:
            # Be cautious with certain drugs in elderly
            if drug_name in self._ELDERLY_CAUTION:
                score *= 0.7

            # Prefer certain drugs in elderly
            if drug_name in self._ELDERLY_PREFERRED:
                score *= 1.2

        return score

    def _apply_gender_adjustments(self, drug_name: str, score: float, gender: str) -> float:
        """Apply gender-specific drug score adjustments.

        Currently a placeholder: no gender rules are defined, so the score is
        returned unchanged for every drug and gender value.
        """
        # Currently minimal gender adjustments
        # Can be expanded based on clinical guidelines
        return score

    def get_mapping_statistics(self) -> Dict[str, Any]:
        """Get statistics about the ICD to drug mappings.

        Returns:
            Dict with keys ``total_icd_codes``, ``total_drugs``,
            ``avg_drugs_per_icd`` (0.0 when there are no mappings),
            ``icd_codes`` (list of mapped codes) and ``all_drugs`` (sorted
            drug names).
        """
        drugs_per_icd = [len(recs) for recs in self.icd_drug_mapping.values()]
        # Guard the empty case: np.mean([]) yields nan plus a RuntimeWarning.
        avg_drugs = float(np.mean(drugs_per_icd)) if drugs_per_icd else 0.0
        return {
            'total_icd_codes': len(self.icd_drug_mapping),
            'total_drugs': len(self.drug_universe),
            'avg_drugs_per_icd': avg_drugs,
            'icd_codes': list(self.icd_drug_mapping.keys()),
            'all_drugs': sorted(self.drug_universe),
        }


def create_icd_to_drug_mapper() -> ICDToDrugMapper:
    """Factory helper: build a new ICDToDrugMapper and hand it back.

    Returns:
        A freshly constructed mapper; construction takes no arguments.
    """
    mapper_instance = ICDToDrugMapper()
    return mapper_instance


# For testing
if __name__ == "__main__":
    # Smoke test: build a mapper and dump its summary statistics.
    demo_mapper = create_icd_to_drug_mapper()
    demo_stats = demo_mapper.get_mapping_statistics()
    print(f"ICD Mapper Statistics: {demo_stats}")

    # Example patient: hypertension ('4019') + diabetes ('25000'), age 65, male.
    demo_codes = ['4019', '25000']
    demo_probs = demo_mapper.generate_drug_probabilities(
        demo_codes, age=65, gender='M'
    )

    print(f"\nTop drug recommendations for HTN+DM, 65M:")
    # Highest probability first; show only the ten leading drugs.
    ranked = sorted(demo_probs.items(), key=lambda pair: pair[1], reverse=True)
    top_ten = ranked[:10]
    for drug_label, drug_prob in top_ten:
        print(f"  {drug_label}: {drug_prob:.4f}")
