"""
MIMIC-IV Prescription Pattern Analytics

This module analyzes MIMIC-IV data to extract statistical patterns for drug prescription
prediction, providing a realistic alternative to random probability generation.

Authors: Research Team
Date: 2024
"""

import json
import logging
from typing import Dict, List, Tuple, Optional, Any, Set
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
import pickle

from ..core.data_structures import ClinicalContext
from .mimic_loader import MimicDataLoader, MimicDataConfig

logger = logging.getLogger(__name__)


@dataclass
class PrescriptionPattern:
    """Prescription statistics learned from MIMIC-IV for a single condition.

    Fields are positional in the order declared below.
    """

    # Diagnosis / condition label the statistics were collected for.
    condition: str
    # Drug name -> empirical probability of that drug under ``condition``.
    drug_probabilities: Dict[str, float]
    # Number of patient records supporting this pattern (as computed by the producer).
    patient_count: int
    # Statistical confidence, expected to lie in [0, 1].
    confidence: float
    

@dataclass
class ConditionalProbability:
    """Drug distribution conditioned on a feature combination, i.e. P(drug | features).

    Fields are positional in the order declared below.
    """

    # Feature tuple the distribution is conditioned on (diagnosis, age group, gender, ...).
    feature_combination: Tuple[str, ...]
    # Drug name -> normalized probability under ``feature_combination``.
    drug_distribution: Dict[str, float]
    # Number of observations the distribution was estimated from.
    sample_size: int
    

class PrescriptionAnalytics:
    """
    Analyzes MIMIC-IV prescription patterns to generate realistic P(drug|clinical_context).

    This class extracts statistical relationships from real prescription data to provide
    a foundation for generating meaningful drug probabilities based on clinical features.
    Three statistics are learned from the training split:

    * ``baseline_drug_frequencies`` - share of each drug among all prescription records;
    * ``diagnosis_drug_patterns`` - drug distribution per diagnosis label;
    * ``conditional_probabilities`` - drug distribution per full feature combination
      (diagnoses + demographic/admission features) of a training row.

    ``generate_probability_distribution`` turns them into a normalized vector aligned
    index-by-index with ``drug_vocabulary``.
    """

    # Row columns read (in this order) as diagnosis labels.
    _DIAGNOSIS_COLUMNS = ('diagnosis', 'icd9_code', 'icd10_code', 'condition')
    # Categorical row columns encoded as "<column>_<value>" features. The same names are
    # read as attributes of a ClinicalContext so training and query features agree.
    _CLINICAL_COLUMNS = ('admission_type', 'insurance', 'ethnicity')
    # Lower-cased spellings of missing values that must never become drug labels.
    _MISSING_TOKENS = frozenset({'nan', 'none', 'null', '<na>'})
    # Weight of the global baseline frequencies added in the pattern-based estimate.
    _BASELINE_WEIGHT = 0.1
    # Weight of the pattern-based estimate in the hybrid mix (conditional gets the rest).
    _HYBRID_PATTERN_WEIGHT = 0.6

    def __init__(self,
                 mimic_loader: "MimicDataLoader",
                 min_pattern_support: int = 10,
                 smoothing_alpha: float = 0.1):
        """
        Initialize prescription analytics.

        Args:
            mimic_loader: MIMIC data loader; its ``load_train_data()`` supplies the
                training DataFrame (must contain a ``labels`` column).
            min_pattern_support: Minimum number of prescription records (drug
                occurrences, summed over rows) a diagnosis or feature combination
                needs before its pattern is kept.
            smoothing_alpha: Additive (Laplace-style) mass given to every drug in
                the vocabulary before pattern evidence is added.

        Raises:
            ValueError: If ``smoothing_alpha`` is negative (would allow negative
                probabilities).
        """
        if smoothing_alpha < 0:
            raise ValueError(f"smoothing_alpha must be non-negative, got {smoothing_alpha}")

        self.mimic_loader = mimic_loader
        self.min_pattern_support = min_pattern_support
        self.smoothing_alpha = smoothing_alpha

        # Learned patterns
        self.diagnosis_drug_patterns: Dict[str, PrescriptionPattern] = {}
        self.conditional_probabilities: List[ConditionalProbability] = []
        self.drug_vocabulary: List[str] = []
        self.baseline_drug_frequencies: Dict[str, float] = {}

        logger.info(f"Initialized PrescriptionAnalytics with support={min_pattern_support}")

    # ------------------------------------------------------------------ analysis

    def analyze_prescription_patterns(self) -> None:
        """
        Analyze MIMIC-IV data to extract prescription patterns.

        Loads the training split and rebuilds the vocabulary, baseline frequencies,
        diagnosis patterns and conditional patterns. Each step replaces (rather than
        extends) its previous result, so the method can safely be called again.
        """
        logger.info("🔍 Analyzing MIMIC-IV prescription patterns...")

        # Load all training data
        train_data = self.mimic_loader.load_train_data()
        logger.info(f"Loaded {len(train_data)} training samples")

        self._extract_drug_vocabulary(train_data)
        self._compute_baseline_frequencies(train_data)
        self._extract_diagnosis_patterns(train_data)
        self._extract_conditional_probabilities(train_data)

        logger.info(f"✅ Analysis complete: {len(self.drug_vocabulary)} drugs, "
                   f"{len(self.diagnosis_drug_patterns)} diagnosis patterns")

    def _extract_drug_vocabulary(self, data: pd.DataFrame) -> None:
        """Replace ``drug_vocabulary`` with the sorted set of drug labels in ``data``."""
        all_drugs: Set[str] = set()
        for _, row in data.iterrows():
            all_drugs.update(self._parse_labels(row['labels']))

        # Filter out empty/invalid drug names
        self.drug_vocabulary = sorted(
            drug for drug in all_drugs if isinstance(drug, str) and drug.strip()
        )
        logger.info(f"Extracted drug vocabulary: {len(self.drug_vocabulary)} drugs")

    def _compute_baseline_frequencies(self, data: pd.DataFrame) -> None:
        """Replace ``baseline_drug_frequencies`` with each drug's share of all records."""
        drug_counts: Counter = Counter()
        for _, row in data.iterrows():
            drug_counts.update(self._parse_labels(row['labels']))

        total_prescriptions = sum(drug_counts.values())
        self.baseline_drug_frequencies = {
            drug: count / total_prescriptions if total_prescriptions > 0 else 0
            for drug, count in drug_counts.items()
        }

        logger.info(f"Computed baseline frequencies for {len(self.baseline_drug_frequencies)} drugs")

    def _extract_diagnosis_patterns(self, data: pd.DataFrame) -> None:
        """Replace ``diagnosis_drug_patterns`` with per-diagnosis drug distributions."""
        diagnosis_prescriptions: Dict[str, List[str]] = defaultdict(list)
        # Number of training rows mentioning each diagnosis (used as patient_count).
        diagnosis_rows: Counter = Counter()

        for _, row in data.iterrows():
            drugs = self._parse_labels(row['labels'])
            for diagnosis in self._extract_diagnoses(row):
                diagnosis_prescriptions[diagnosis].extend(drugs)
                diagnosis_rows[diagnosis] += 1

        patterns: Dict[str, PrescriptionPattern] = {}
        for diagnosis, drug_list in diagnosis_prescriptions.items():
            total_drugs = len(drug_list)
            # Support is measured in prescription records, not in rows.
            if total_drugs < self.min_pattern_support:
                continue
            patterns[diagnosis] = PrescriptionPattern(
                condition=diagnosis,
                drug_probabilities={
                    drug: count / total_drugs
                    for drug, count in Counter(drug_list).items()
                },
                patient_count=diagnosis_rows[diagnosis],
                # Grows linearly with record count and saturates at 1.0 at 100 records.
                confidence=min(1.0, total_drugs / 100),
            )

        self.diagnosis_drug_patterns = patterns
        logger.info(f"Extracted {len(self.diagnosis_drug_patterns)} diagnosis patterns")

    def _extract_conditional_probabilities(self, data: pd.DataFrame) -> None:
        """Replace ``conditional_probabilities`` with P(drug | feature_combination)."""
        feature_combinations: Dict[Tuple[str, ...], List[str]] = defaultdict(list)

        for _, row in data.iterrows():
            # Sorting makes the key independent of feature extraction order.
            feature_key = tuple(sorted(self._extract_features(row)))
            feature_combinations[feature_key].extend(self._parse_labels(row['labels']))

        conditionals: List[ConditionalProbability] = []
        for feature_combo, drug_list in feature_combinations.items():
            total_drugs = len(drug_list)
            if total_drugs < self.min_pattern_support:
                continue
            conditionals.append(ConditionalProbability(
                feature_combination=feature_combo,
                drug_distribution={
                    drug: count / total_drugs
                    for drug, count in Counter(drug_list).items()
                },
                sample_size=total_drugs,
            ))

        self.conditional_probabilities = conditionals
        logger.info(f"Extracted {len(self.conditional_probabilities)} conditional patterns")

    # ------------------------------------------------------------ label parsing

    def _parse_labels(self, labels: Any) -> List[str]:
        """
        Parse drug labels from various formats into a list of clean strings.

        Accepts None/NaN (-> []), containers (list, tuple, set, numpy array, pandas
        Series), JSON strings (a list, or a quoted comma-separated string),
        plain comma-separated strings, and any other scalar (stringified).
        Blank entries and textual missing markers ("nan", "none", ...) are dropped.
        """
        if labels is None:
            return []

        # Containers first: pd.isna() on a container returns an array whose truth
        # value is ambiguous, so it must not be used as a scalar test here.
        if isinstance(labels, np.ndarray):
            return self._clean_label_items(labels.ravel().tolist())
        if isinstance(labels, (list, tuple, set, frozenset, pd.Series)):
            return self._clean_label_items(list(labels))

        try:
            if pd.isna(labels):
                return []
        except (TypeError, ValueError):
            pass  # Not a pandas-compatible scalar

        if isinstance(labels, str):
            return self._parse_label_string(labels)

        # Any other scalar (numbers, numpy/pandas scalars, ...)
        try:
            text = str(labels).strip()
        except Exception:  # __str__ of arbitrary objects may raise anything
            logger.debug("Could not stringify label value of type %s", type(labels).__name__)
            return []
        return self._clean_label_items([text])

    def _parse_label_string(self, text: str) -> List[str]:
        """Parse a JSON-encoded or comma-separated label string."""
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return self._clean_label_items(text.split(','))
        if isinstance(parsed, list):
            return self._clean_label_items(parsed)
        if isinstance(parsed, str):
            # e.g. '"aspirin, heparin"' -> unwrap the JSON string, then split.
            return self._clean_label_items(parsed.split(','))
        # JSON scalar (number, bool, null, NaN): keep the original spelling.
        return self._clean_label_items([text])

    def _clean_label_items(self, items) -> List[str]:
        """Stringify and strip each item, dropping None, blanks and missing markers."""
        cleaned: List[str] = []
        for item in items:
            if item is None:
                continue
            text = str(item).strip()
            if text and text.lower() not in self._MISSING_TOKENS:
                cleaned.append(text)
        return cleaned

    # --------------------------------------------------------- feature helpers

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for None, NaN/NA scalars and blank strings."""
        if value is None:
            return True
        if isinstance(value, str):
            return not value.strip()
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            # Containers: pd.isna returns an array; treat the value as present.
            return False

    @staticmethod
    def _age_group(age: Any) -> Optional[str]:
        """Map an age to 'age_pediatric' (<18), 'age_adult' (<65) or 'age_elderly'.

        Returns None when the age is missing or not numeric.
        """
        try:
            value = float(age)
        except (TypeError, ValueError):
            return None
        if np.isnan(value):
            return None
        if value < 18:
            return 'age_pediatric'
        if value < 65:
            return 'age_adult'
        return 'age_elderly'

    @staticmethod
    def _categorical_feature(column: str, value: Any) -> str:
        """Encode a categorical value as '<column>_<lowercased_value_with_underscores>'."""
        return f"{column}_{str(value).lower().replace(' ', '_')}"

    def _extract_diagnoses(self, row: pd.Series) -> List[str]:
        """Return the distinct diagnosis labels of a row, or ['general'] if none."""
        diagnoses: List[str] = []
        for col in self._DIAGNOSIS_COLUMNS:
            if col in row and not self._is_missing(row[col]):
                diagnosis = str(row[col]).strip()
                # The same code may appear in several columns; count it once.
                if diagnosis not in diagnoses:
                    diagnoses.append(diagnosis)

        # Fallback: use a generic category
        return diagnoses or ['general']

    def _extract_features(self, row: pd.Series) -> List[str]:
        """Extract diagnosis, demographic and admission features from a training row."""
        features = list(self._extract_diagnoses(row))

        age_feature = self._age_group(row['age']) if 'age' in row else None
        if age_feature:
            features.append(age_feature)

        if 'gender' in row and not self._is_missing(row['gender']):
            features.append(f"gender_{str(row['gender']).lower()}")

        for col in self._CLINICAL_COLUMNS:
            if col in row and not self._is_missing(row[col]):
                features.append(self._categorical_feature(col, row[col]))

        return features

    def _context_diagnoses(self, clinical_context: "ClinicalContext") -> List[str]:
        """Return the context's diagnoses stripped, stringified and de-duplicated (in order)."""
        raw = getattr(clinical_context, 'diagnoses', None) or []
        if isinstance(raw, str):
            raw = [raw]  # a lone string must not be iterated character by character
        return list(dict.fromkeys(
            str(diagnosis).strip() for diagnosis in raw if not self._is_missing(diagnosis)
        ))

    def _context_to_features(self, clinical_context: "ClinicalContext") -> List[str]:
        """Convert a clinical context to features using the same scheme as ``_extract_features``."""
        # Training rows without a diagnosis carry 'general'; mirror that here.
        features = self._context_diagnoses(clinical_context) or ['general']

        age_feature = self._age_group(getattr(clinical_context, 'age', None))
        if age_feature:
            features.append(age_feature)

        gender = getattr(clinical_context, 'gender', None)
        if not self._is_missing(gender):
            features.append(f"gender_{str(gender).lower()}")

        # Optional attributes; absent ones are simply skipped.
        for col in self._CLINICAL_COLUMNS:
            value = getattr(clinical_context, col, None)
            if not self._is_missing(value):
                features.append(self._categorical_feature(col, value))

        return features

    # --------------------------------------------------------------- generation

    def _vocab_index(self) -> Dict[str, int]:
        """Map each drug to its first position in ``drug_vocabulary`` (O(1) lookups)."""
        index: Dict[str, int] = {}
        for position, drug in enumerate(self.drug_vocabulary):
            index.setdefault(drug, position)
        return index

    def _smoothed_prior(self) -> np.ndarray:
        """Float vector filled with ``smoothing_alpha`` (float dtype even for int alpha)."""
        return np.full(len(self.drug_vocabulary), self.smoothing_alpha, dtype=float)

    @staticmethod
    def _normalize(probs: np.ndarray) -> np.ndarray:
        """Scale ``probs`` to sum to 1; fall back to uniform when the mass is zero/invalid."""
        total = float(np.sum(probs))
        if not np.isfinite(total) or total <= 0:
            # E.g. smoothing_alpha == 0 and no pattern matched: avoid 0/0 -> NaN.
            return np.full(len(probs), 1.0 / len(probs)) if len(probs) else probs
        return probs / total

    def generate_probability_distribution(self,
                                        clinical_context: "ClinicalContext",
                                        method: str = 'hybrid') -> np.ndarray:
        """
        Generate drug probability distribution P(drug|clinical_context).

        Args:
            clinical_context: Patient clinical context (reads ``diagnoses``, ``age``,
                ``gender`` and, if present, ``admission_type``/``insurance``/``ethnicity``).
            method: Generation method ('pattern', 'conditional', 'hybrid')

        Returns:
            Probability distribution over drug vocabulary (sums to 1).

        Raises:
            ValueError: If the vocabulary is empty or ``method`` is unknown.
        """
        if not self.drug_vocabulary:
            raise ValueError("Drug vocabulary not initialized. Call analyze_prescription_patterns() first.")

        if method == 'pattern':
            return self._generate_pattern_based(clinical_context)
        elif method == 'conditional':
            return self._generate_conditional_based(clinical_context)
        elif method == 'hybrid':
            return self._generate_hybrid(clinical_context)
        else:
            raise ValueError(f"Unknown method: {method}")

    def _generate_pattern_based(self, clinical_context: "ClinicalContext") -> np.ndarray:
        """Smoothed prior + confidence-weighted diagnosis patterns + scaled baseline."""
        index = self._vocab_index()
        probs = self._smoothed_prior()

        for diagnosis in self._context_diagnoses(clinical_context) or ['general']:
            pattern = self.diagnosis_drug_patterns.get(diagnosis)
            if pattern is None:
                continue
            for drug, prob in pattern.drug_probabilities.items():
                position = index.get(drug)
                if position is not None:
                    probs[position] += prob * pattern.confidence

        for drug, frequency in self.baseline_drug_frequencies.items():
            position = index.get(drug)
            if position is not None:
                probs[position] += frequency * self._BASELINE_WEIGHT

        return self._normalize(probs)

    def _generate_conditional_based(self, clinical_context: "ClinicalContext") -> np.ndarray:
        """Smoothed prior + the conditional pattern sharing the most features with the context."""
        index = self._vocab_index()
        probs = self._smoothed_prior()
        context_features = set(self._context_to_features(clinical_context))

        best_match = None
        best_overlap = 0
        for cond_prob in self.conditional_probabilities:
            overlap = len(context_features.intersection(cond_prob.feature_combination))
            # Strict '>' keeps the earliest pattern on ties; zero overlap never matches.
            if overlap > best_overlap:
                best_overlap = overlap
                best_match = cond_prob

        if best_match is not None:
            for drug, prob in best_match.drug_distribution.items():
                position = index.get(drug)
                if position is not None:
                    probs[position] += prob

        return self._normalize(probs)

    def _generate_hybrid(self, clinical_context: "ClinicalContext") -> np.ndarray:
        """Weighted mix (0.6 pattern / 0.4 conditional) of the two estimates."""
        pattern_probs = self._generate_pattern_based(clinical_context)
        conditional_probs = self._generate_conditional_based(clinical_context)
        weight = self._HYBRID_PATTERN_WEIGHT
        return self._normalize(weight * pattern_probs + (1.0 - weight) * conditional_probs)

    # -------------------------------------------------------------- persistence

    def save_patterns(self, filepath: Path) -> None:
        """Save learned patterns (vocabulary, frequencies, patterns) to a pickle file."""
        patterns_data = {
            'diagnosis_patterns': self.diagnosis_drug_patterns,
            'conditional_probabilities': self.conditional_probabilities,
            'drug_vocabulary': self.drug_vocabulary,
            'baseline_frequencies': self.baseline_drug_frequencies
        }

        with open(filepath, 'wb') as f:
            pickle.dump(patterns_data, f)

        logger.info(f"Saved prescription patterns to {filepath}")

    def load_patterns(self, filepath: Path) -> None:
        """Load learned patterns previously written by ``save_patterns``.

        SECURITY: pickle.load can execute arbitrary code; only load files from a
        trusted source (e.g. produced by this class).
        """
        with open(filepath, 'rb') as f:
            patterns_data = pickle.load(f)

        self.diagnosis_drug_patterns = patterns_data['diagnosis_patterns']
        self.conditional_probabilities = patterns_data['conditional_probabilities']
        self.drug_vocabulary = patterns_data['drug_vocabulary']
        self.baseline_drug_frequencies = patterns_data['baseline_frequencies']

        logger.info(f"Loaded prescription patterns from {filepath}")


def create_prescription_analytics(data_root: str = "data/processed/mimic",
                                  **kwargs) -> PrescriptionAnalytics:
    """
    Build a PrescriptionAnalytics wired to a MIMIC data loader.

    No data is analyzed here: call ``analyze_prescription_patterns()`` (or
    ``load_patterns()``) on the result before generating distributions.

    Args:
        data_root: Directory of the processed MIMIC data, passed to MimicDataConfig.
        **kwargs: Forwarded unchanged to the PrescriptionAnalytics constructor.

    Returns:
        A new PrescriptionAnalytics instance backed by a fresh MimicDataLoader.
    """
    loader = MimicDataLoader(MimicDataConfig(data_root=data_root))
    return PrescriptionAnalytics(loader, **kwargs)
