"""
CNCRC Core Data Structures

This module defines the core data structures for the Conformal Non-Coverage 
Risk Control (CNCRC) framework, specifically designed for safe drug recommendation.
"""
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Set, Union
from enum import Enum
import numpy as np
from scipy import sparse
import logging

logger = logging.getLogger(__name__)


class InteractionSeverity(Enum):
    """Severity levels for a drug-drug interaction, from none to contraindicated."""
    NONE = "none"           # No interaction
    MINOR = "minor"         # Minor interaction (monitoring may be needed)
    MODERATE = "moderate"   # Moderate interaction (dose adjustment may be needed)
    MAJOR = "major"         # Major interaction (combination usually avoided)
    SEVERE = "severe"       # Severe interaction (combination contraindicated)

    def to_numeric(self) -> float:
        """
        Return the cost weight of this severity level.

        Returns:
            0.0 for NONE, 0.2 for MINOR, 0.5 for MODERATE, 0.8 for MAJOR and
            1.0 for SEVERE.
        """
        # Keyed by the member's string value so the table can live locally
        # (a dict defined in the Enum body would itself become a member).
        weight_by_value = {
            "none": 0.0,
            "minor": 0.2,
            "moderate": 0.5,
            "major": 0.8,
            "severe": 1.0,
        }
        return weight_by_value[self.value]


@dataclass
class DrugInteraction:
    """
    A single drug-drug interaction (DDI) record together with its severity.

    Links two drug identifiers and carries the severity level plus optional
    descriptive metadata (description, mechanism, provenance, confidence).

    Attributes:
        drug_a: Identifier of the first drug (e.g., DrugBank ID)
        drug_b: Identifier of the second drug
        severity: Interaction severity level
        description: Human-readable description of the interaction
        mechanism: Mechanism of interaction, if known
        source: Provenance of the record (defaults to "DrugBank")
        confidence: Confidence in the record; validated to lie in [0, 1]
    """
    drug_a: str
    drug_b: str
    severity: InteractionSeverity
    description: str = ""
    mechanism: Optional[str] = None
    source: str = "DrugBank"
    confidence: float = 1.0

    def __post_init__(self):
        """Reject malformed records immediately after construction.

        Raises:
            ValueError: if a drug identifier is not a non-blank string, or
                confidence lies outside [0, 1] (NaN is also rejected).
            TypeError: if severity is not an InteractionSeverity member.
        """
        first_ok = isinstance(self.drug_a, str) and bool(self.drug_a.strip())
        if not first_ok:
            raise ValueError("drug_a must be a non-empty string")

        second_ok = isinstance(self.drug_b, str) and bool(self.drug_b.strip())
        if not second_ok:
            raise ValueError("drug_b must be a non-empty string")

        if not isinstance(self.severity, InteractionSeverity):
            raise TypeError("severity must be an InteractionSeverity enum")

        # Written as a chained comparison so NaN fails the range test.
        in_unit_interval = 0.0 <= self.confidence <= 1.0
        if not in_unit_interval:
            raise ValueError("confidence must be between 0.0 and 1.0")

    def get_cost(self) -> float:
        """
        Compute the cost contributed by this interaction.

        The cost is a linear blend, weighted by ``confidence``, between the
        severity's numeric weight and a fixed uncertainty floor of 0.1:
        ``confidence * severity + (1 - confidence) * 0.1``. At confidence 1
        the cost equals the severity weight. As confidence drops, the cost
        moves toward 0.1: it rises for NONE interactions and falls for
        MINOR and above.

        Returns:
            Cost value, capped at 1.0
        """
        weight = self.confidence
        blended = weight * self.severity.to_numeric() + 0.1 * (1 - weight)
        return min(blended, 1.0)

    def is_symmetric(self, other: 'DrugInteraction') -> bool:
        """
        Tell whether ``other`` describes the same unordered drug pair.

        Args:
            other: Interaction to compare against

        Returns:
            True when both records name the same two drugs, in either order
        """
        mine = (self.drug_a, self.drug_b)
        same_order = mine == (other.drug_a, other.drug_b)
        swapped_order = mine == (other.drug_b, other.drug_a)
        return same_order or swapped_order

    def __str__(self) -> str:
        label = self.severity.value
        cost = self.get_cost()
        return f"DDI({self.drug_a} ↔ {self.drug_b}, {label}, cost={cost:.2f})"

    def __repr__(self) -> str:
        head = f"DrugInteraction(drug_a='{self.drug_a}', drug_b='{self.drug_b}', "
        tail = f"severity={self.severity}, confidence={self.confidence})"
        return head + tail


@dataclass
class ClinicalContext:
    """
    Represents patient-specific clinical information from MIMIC-IV.

    This structure holds all relevant clinical context needed for safe
    drug recommendation, including demographics, diagnoses, and current medications.

    Attributes:
        patient_id: Unique patient identifier
        age: Patient age in years (Python or NumPy integer; stored as ``int``)
        gender: Patient gender ('M', 'F', 'O' for other, or 'UNKNOWN')
        weight: Patient weight in kg (optional; Python or NumPy number)
        diagnoses: List of ICD diagnosis codes
        current_medications: List of currently prescribed drugs
        allergies: Known drug allergies (drug identifiers)
        admission_type: Type of hospital admission
        icu_stay: Whether patient is in ICU
        comorbidities: List of comorbidity conditions
        lab_values: Recent laboratory values (dict)
        vital_signs: Recent vital signs (dict)
        admission_time: Time of hospital admission
        metadata: Additional clinical metadata
    """
    patient_id: str
    age: int
    gender: str
    diagnoses: List[str] = field(default_factory=list)
    current_medications: List[str] = field(default_factory=list)
    allergies: List[str] = field(default_factory=list)
    weight: Optional[float] = None
    admission_type: str = "UNKNOWN"
    icu_stay: bool = False
    comorbidities: List[str] = field(default_factory=list)
    lab_values: Dict[str, float] = field(default_factory=dict)
    vital_signs: Dict[str, float] = field(default_factory=dict)
    admission_time: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate clinical context data after initialization.

        Raises:
            ValueError: on an empty patient_id, a non-integer or out-of-range
                age, an unknown gender code, a non-positive or non-numeric
                weight, or a diagnosis code shorter than 3 characters.
        """
        if not isinstance(self.patient_id, str) or not self.patient_id.strip():
            raise ValueError("patient_id must be a non-empty string")

        # Accept NumPy integers: values read through pandas (e.g. from
        # MIMIC-IV tables) are np.int64, which is not a subclass of int.
        # Reject bool, which *is* an int subclass but never a valid age.
        if isinstance(self.age, bool) or not isinstance(self.age, (int, np.integer)):
            raise ValueError("age must be an integer between 0 and 150")
        self.age = int(self.age)  # normalise NumPy scalars to a plain int
        if not 0 <= self.age <= 150:
            raise ValueError("age must be an integer between 0 and 150")

        if self.gender not in ['M', 'F', 'O', 'UNKNOWN']:
            raise ValueError("gender must be 'M', 'F', 'O', or 'UNKNOWN'")

        if self.weight is not None:
            # np.float32 is not a float subclass, so list NumPy types explicitly.
            is_number = (
                isinstance(self.weight, (int, float, np.integer, np.floating))
                and not isinstance(self.weight, bool)
            )
            if not is_number or self.weight <= 0:
                raise ValueError("weight must be a positive number")

        # Validate diagnosis codes (basic ICD format check)
        for diagnosis in self.diagnoses:
            if not isinstance(diagnosis, str) or len(diagnosis) < 3:
                raise ValueError(f"Invalid diagnosis code: {diagnosis}")

    def add_diagnosis(self, diagnosis_code: str) -> None:
        """Add a new diagnosis to the patient's record.

        Falsy codes and codes already present are ignored. Note that the
        3-character check from ``__post_init__`` is not applied here.
        """
        if diagnosis_code and diagnosis_code not in self.diagnoses:
            self.diagnoses.append(diagnosis_code)

    def add_medication(self, drug_id: str) -> None:
        """Add a medication to current medications list (no duplicates, falsy ids ignored)."""
        if drug_id and drug_id not in self.current_medications:
            self.current_medications.append(drug_id)

    def add_allergy(self, drug_id: str) -> None:
        """Add a drug allergy to the patient's record (no duplicates, falsy ids ignored)."""
        if drug_id and drug_id not in self.allergies:
            self.allergies.append(drug_id)

    def has_allergy(self, drug_id: str) -> bool:
        """Check if patient is allergic to a specific drug."""
        return drug_id in self.allergies

    def is_contraindicated(self, drug_id: str) -> bool:
        """
        Check if a drug is contraindicated for this patient.
        Currently only checks allergies, can be extended for other contraindications.
        """
        return self.has_allergy(drug_id)

    def get_risk_factors(self) -> Dict[str, Any]:
        """
        Extract relevant risk factors for drug recommendation.

        Returns:
            Dictionary of risk factors that may affect drug safety
        """
        risk_factors = {
            'elderly': self.age >= 65,
            'pediatric': self.age < 18,
            'icu_patient': self.icu_stay,
            'polypharmacy': len(self.current_medications) >= 5,
            'has_allergies': len(self.allergies) > 0,
            'multiple_diagnoses': len(self.diagnoses) >= 3,
            'comorbidity_count': len(self.comorbidities)
        }

        # ICD-10 prefixes: heart failure (I50), CKD (N18), hepatic failure (K72).
        # NOTE(review): only these ICD-10 prefixes are checked; ICD-9 codes
        # for the same conditions will not match.
        high_risk_prefixes = ('I50', 'N18', 'K72')
        risk_factors['high_risk_conditions'] = any(
            diag.startswith(high_risk_prefixes) for diag in self.diagnoses
        )

        return risk_factors

    def get_summary(self) -> str:
        """Generate a concise summary of the clinical context."""
        summary_parts = [
            f"Patient {self.patient_id}",
            f"{self.age}y {self.gender}",
            f"{len(self.diagnoses)} diagnoses",
            f"{len(self.current_medications)} medications"
        ]

        if self.icu_stay:
            summary_parts.append("ICU")

        if self.allergies:
            summary_parts.append(f"{len(self.allergies)} allergies")

        return " | ".join(summary_parts)

    def __str__(self) -> str:
        return self.get_summary()

    def __repr__(self) -> str:
        return (f"ClinicalContext(patient_id='{self.patient_id}', age={self.age}, "
                f"gender='{self.gender}', diagnoses={len(self.diagnoses)}, "
                f"medications={len(self.current_medications)})")


@dataclass
class CostMatrix:
    """
    Wrapper for drug-drug interaction cost matrix.

    This class provides a convenient interface for storing and querying
    costs between drug pairs. Supports both dense (NumPy) and sparse
    (SciPy) matrix representations for efficiency.

    Sparse semantics: entries not stored in a sparse matrix read as 0.0.
    ``from_interactions``/``create_empty`` place ``default_cost`` only on the
    diagonal in sparse mode, so a non-zero ``default_cost`` is not applied to
    off-diagonal non-interacting pairs there (a warning is logged).

    Attributes:
        drug_ids: List of drug identifiers (index mapping). Copied on
            construction so ``add_drug`` never mutates a caller-owned list.
        matrix: Cost matrix (NumPy array or SciPy sparse matrix). Sparse
            formats without element indexing (e.g. COO) are converted to CSR.
        is_sparse: Whether the underlying matrix is sparse (derived from
            ``matrix``; any value passed in is overwritten)
        default_cost: Cost returned for drugs absent from ``drug_ids`` and
            used to fill new dense entries
        symmetric: Whether ``set_cost``/``add_drug`` mirror writes to (j, i)
        metadata: Additional metadata about the cost matrix
    """
    drug_ids: List[str]
    matrix: Union[np.ndarray, sparse.spmatrix]
    is_sparse: bool = False
    default_cost: float = 0.0
    symmetric: bool = True
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate cost matrix after initialization.

        Raises:
            ValueError: if drug_ids is empty or has duplicates, or the matrix
                shape is not (n_drugs, n_drugs).
            TypeError: if matrix is neither a NumPy array nor a SciPy sparse matrix.
        """
        # Own a private copy: from_interactions()/create_empty() pass the
        # caller's list straight through, and add_drug() appends to it.
        self.drug_ids = list(self.drug_ids)

        if not self.drug_ids:
            raise ValueError("drug_ids cannot be empty")

        if len(set(self.drug_ids)) != len(self.drug_ids):
            raise ValueError("drug_ids must be unique")

        n_drugs = len(self.drug_ids)

        # Validate matrix dimensions
        if not hasattr(self.matrix, 'shape'):
            raise TypeError("matrix must be a NumPy array or SciPy sparse matrix")
        if self.matrix.shape != (n_drugs, n_drugs):
            raise ValueError(f"Matrix shape {self.matrix.shape} doesn't match drug count {n_drugs}")

        # Check if matrix is actually sparse
        if sparse.issparse(self.matrix):
            self.is_sparse = True
            # COO/DIA/BSR do not support m[i, j] reads/writes used below.
            if self.matrix.format not in ("csr", "csc", "lil", "dok"):
                self.matrix = self.matrix.tocsr()
        elif isinstance(self.matrix, np.ndarray):
            self.is_sparse = False
        else:
            raise TypeError("matrix must be NumPy array or SciPy sparse matrix")

        # Validate cost range. tocoo() gives a numeric .data for every sparse
        # format (lil's .data is an object array of lists, dok has none).
        if self.is_sparse:
            stored_costs = self.matrix.tocoo().data
        else:
            stored_costs = self.matrix.ravel()

        if stored_costs.size > 0 and (np.any(stored_costs < 0) or np.any(stored_costs > 1)):
            logger.warning("Cost matrix contains values outside [0,1] range")

    @classmethod
    def from_interactions(
        cls,
        drug_ids: List[str],
        interactions: List['DrugInteraction'],
        use_sparse: bool = True,
        default_cost: float = 0.0
    ) -> 'CostMatrix':
        """
        Create a symmetric cost matrix from a list of drug interactions.

        When several interactions name the same unordered drug pair (e.g. the
        same DDI listed in both directions), the highest cost wins. This
        makes the result independent of input order and errs on the safe side.

        Args:
            drug_ids: List of all drug identifiers
            interactions: Objects exposing ``drug_a``, ``drug_b`` and
                ``get_cost()`` (normally DrugInteraction instances)
            use_sparse: Whether to use sparse matrix representation
            default_cost: Fill value for non-interacting pairs in dense mode;
                in sparse mode it is applied to the diagonal only

        Returns:
            CostMatrix instance
        """
        n_drugs = len(drug_ids)
        drug_to_idx = {drug_id: idx for idx, drug_id in enumerate(drug_ids)}

        # Aggregate per unordered pair (max cost) before writing.
        pair_costs: Dict[tuple, float] = {}
        for interaction in interactions:
            try:
                i = drug_to_idx[interaction.drug_a]
                j = drug_to_idx[interaction.drug_b]
            except KeyError as e:
                logger.warning(f"Drug {e} in interaction not found in drug_ids")
                continue
            key = (min(i, j), max(i, j))
            cost = interaction.get_cost()
            pair_costs[key] = max(cost, pair_costs.get(key, cost))

        if use_sparse:
            if default_cost != 0:
                logger.warning(
                    "Sparse CostMatrix applies default_cost=%s to the diagonal only; "
                    "off-diagonal non-interacting pairs are stored as 0.0",
                    default_cost,
                )
            matrix = sparse.lil_matrix((n_drugs, n_drugs))
            if default_cost != 0 and n_drugs > 0:
                matrix.setdiag(default_cost)  # Diagonal (self-interactions)
        else:
            matrix = np.full((n_drugs, n_drugs), default_cost, dtype=np.float32)

        # Fill in interaction costs (both triangles; i == j writes once twice)
        for (i, j), cost in pair_costs.items():
            matrix[i, j] = cost
            matrix[j, i] = cost

        if use_sparse:
            matrix = matrix.tocsr()  # Convert to efficient format

        return cls(
            drug_ids=drug_ids,
            matrix=matrix,
            is_sparse=use_sparse,
            default_cost=default_cost,
            symmetric=True
        )

    @classmethod
    def create_empty(
        cls,
        drug_ids: List[str],
        use_sparse: bool = True,
        default_cost: float = 0.0
    ) -> 'CostMatrix':
        """Create a cost matrix with no interactions.

        Equivalent to ``from_interactions`` with an empty interaction list, so
        dense and sparse construction treat ``default_cost`` the same way.
        """
        return cls.from_interactions(
            drug_ids, [], use_sparse=use_sparse, default_cost=default_cost
        )

    def get_cost(self, drug_a: str, drug_b: str) -> float:
        """
        Get the cost between two drugs.

        Args:
            drug_a: First drug identifier
            drug_b: Second drug identifier

        Returns:
            Cost value between the two drugs, or ``default_cost`` when either
            drug is not in the matrix
        """
        try:
            i = self.drug_ids.index(drug_a)
            j = self.drug_ids.index(drug_b)
        except ValueError:
            # Drug not found in matrix
            return self.default_cost
        return float(self.matrix[i, j])

    def set_cost(self, drug_a: str, drug_b: str, cost: float) -> None:
        """
        Set the cost between two drugs (mirrored to (b, a) when symmetric).

        Args:
            drug_a: First drug identifier
            drug_b: Second drug identifier
            cost: Cost value to set, in [0, 1]

        Raises:
            ValueError: if cost is outside [0, 1] or a drug is unknown
        """
        if not 0 <= cost <= 1:
            raise ValueError("Cost must be between 0 and 1")

        try:
            i = self.drug_ids.index(drug_a)
            j = self.drug_ids.index(drug_b)
        except ValueError:
            raise ValueError(f"One or both drugs not found: {drug_a}, {drug_b}") from None

        self.matrix[i, j] = cost
        if self.symmetric and i != j:
            self.matrix[j, i] = cost

    def get_row_costs(self, drug_id: str) -> np.ndarray:
        """
        Get all costs for a specific drug (one row of the matrix).

        Args:
            drug_id: Drug identifier

        Returns:
            Array of costs for this drug with all other drugs; an array of
            ``default_cost`` when the drug is unknown
        """
        try:
            i = self.drug_ids.index(drug_id)
        except ValueError:
            # Drug not found, return default costs
            return np.full(len(self.drug_ids), self.default_cost)
        if self.is_sparse:
            return self.matrix[i, :].toarray().flatten()
        return self.matrix[i, :].copy()

    def get_max_cost_for_drug(self, drug_id: str) -> float:
        """
        Get the maximum cost for a drug with any other drug.

        This is useful for the CNCRC risk calculation where we need
        max_{j≠y} Cost(y, j). The drug's own diagonal entry is excluded.
        """
        row_costs = self.get_row_costs(drug_id)

        try:
            i = self.drug_ids.index(drug_id)
        except ValueError:
            return float(np.max(row_costs))

        others = np.delete(row_costs, i)
        if others.size:
            return float(np.max(others))
        return self.default_cost

    def add_drug(self, drug_id: str, initial_costs: Optional[List[float]] = None) -> None:
        """
        Add a new drug to the cost matrix.

        Args:
            drug_id: New drug identifier
            initial_costs: Optional list of costs with existing drugs (in
                ``drug_ids`` order); defaults to ``default_cost`` for each

        Raises:
            ValueError: if the drug already exists or initial_costs has the
                wrong length
        """
        if drug_id in self.drug_ids:
            raise ValueError(f"Drug {drug_id} already exists in matrix")

        n_existing = len(self.drug_ids)
        n_new = n_existing + 1

        if initial_costs is None:
            initial_costs = [self.default_cost] * n_existing
        elif len(initial_costs) != n_existing:
            raise ValueError(f"initial_costs length {len(initial_costs)} != existing drugs {n_existing}")

        new_costs = np.asarray(initial_costs, dtype=np.float64)

        if self.is_sparse:
            # Grow via COO triplets so the existing data is never densified.
            old = self.matrix.tocoo()
            existing_idx = np.arange(n_existing)
            new_idx = np.full(n_existing, n_existing)

            rows = [old.row, new_idx, [n_existing]]
            cols = [old.col, existing_idx, [n_existing]]
            vals = [old.data, new_costs, [self.default_cost]]  # last = self-interaction
            if self.symmetric:
                rows.append(existing_idx)
                cols.append(new_idx)
                vals.append(new_costs)

            grown = sparse.coo_matrix(
                (np.concatenate(vals).astype(np.float64),
                 (np.concatenate(rows), np.concatenate(cols))),
                shape=(n_new, n_new),
            ).tocsr()
            grown.eliminate_zeros()  # zeros are implicit in sparse storage
            self.matrix = grown
        else:
            new_matrix = np.full((n_new, n_new), self.default_cost, dtype=self.matrix.dtype)
            new_matrix[:n_existing, :n_existing] = self.matrix
            new_matrix[n_existing, :n_existing] = new_costs
            if self.symmetric:
                new_matrix[:n_existing, n_existing] = new_costs
            self.matrix = new_matrix

        self.drug_ids.append(drug_id)

    def to_dense(self) -> np.ndarray:
        """Convert to dense NumPy array (always a copy)."""
        if self.is_sparse:
            return self.matrix.toarray()
        return self.matrix.copy()

    def to_sparse(self) -> sparse.csr_matrix:
        """Convert to a sparse CSR matrix (always a copy, always CSR format)."""
        if self.is_sparse:
            return self.matrix.tocsr(copy=True)
        return sparse.csr_matrix(self.matrix)

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the cost matrix.

        All statistics cover the full n*n matrix. For sparse storage the
        unstored entries count as 0.0, so a sparse and a dense representation
        of the same matrix report the same values (up to float precision).
        ``sparsity`` is the fraction of entries equal to zero.
        """
        n_rows, n_cols = self.matrix.shape
        total = n_rows * n_cols  # NB: sparse .size is the stored count, not n*n

        if self.is_sparse:
            canonical = self.matrix.tocsr(copy=True)
            canonical.sum_duplicates()
            values = canonical.data
            n_implicit = total - values.size
        else:
            values = np.asarray(self.matrix).ravel()
            n_implicit = 0

        nonzero_count = int(np.count_nonzero(values))
        sparsity = 1.0 - (nonzero_count / total) if total > 0 else 0.0

        if total == 0:
            min_cost = max_cost = mean_cost = self.default_cost
        else:
            min_cost = float(np.min(values)) if values.size else 0.0
            max_cost = float(np.max(values)) if values.size else 0.0
            if n_implicit:
                min_cost = min(min_cost, 0.0)
                max_cost = max(max_cost, 0.0)
            mean_cost = float(np.sum(values, dtype=np.float64)) / total

        return {
            'n_drugs': len(self.drug_ids),
            'matrix_shape': self.matrix.shape,
            'is_sparse': self.is_sparse,
            'sparsity': sparsity,
            'nonzero_elements': nonzero_count,
            'min_cost': min_cost,
            'max_cost': max_cost,
            'mean_cost': mean_cost,
            'default_cost': self.default_cost
        }

    def __str__(self) -> str:
        stats = self.get_stats()
        return (f"CostMatrix({stats['n_drugs']} drugs, "
                f"{'sparse' if self.is_sparse else 'dense'}, "
                f"sparsity={stats['sparsity']:.2f})")

    def __repr__(self) -> str:
        return (f"CostMatrix(n_drugs={len(self.drug_ids)}, "
                f"is_sparse={self.is_sparse}, "
                f"default_cost={self.default_cost})")


@dataclass
class PredictionSet:
    """
    Represents a CNCRC prediction set with recommended drugs and risk metrics.

    This structure stores the output of the CNCRC algorithm: a set of recommended
    drugs that satisfy the risk threshold, along with associated risk scores
    and metadata about the prediction process.

    Attributes:
        candidates: List of drug identifiers in the prediction set
            (deduplicated on construction, order preserved)
        risk_scores: Risk scores for each candidate (s(x,y) values); must be
            non-negative numbers (NumPy scalars accepted, NaN rejected)
        probabilities: Model probabilities for each candidate (optional), in [0, 1]
        clinical_context: Associated clinical context
        threshold: Risk threshold used (q value)
        alpha: Risk level parameter
        timestamp: When the prediction was made
        model_info: Information about the underlying model
        metadata: Additional prediction metadata
    """
    candidates: List[str]
    risk_scores: Dict[str, float] = field(default_factory=dict)
    probabilities: Dict[str, float] = field(default_factory=dict)
    clinical_context: Optional['ClinicalContext'] = None
    threshold: Optional[float] = None
    alpha: Optional[float] = None
    timestamp: Optional[str] = None
    model_info: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Accepted numeric types for scores/probabilities. NumPy scalars are listed
    # because model outputs are often np.float32, which is not a float subclass.
    _NUMERIC_TYPES = (int, float, np.integer, np.floating)

    @staticmethod
    def _check_risk_score(drug_id: str, score: Any) -> None:
        """Raise ValueError unless ``score`` is a non-negative, non-NaN number."""
        if (not isinstance(score, PredictionSet._NUMERIC_TYPES)
                or score != score  # NaN is the only value unequal to itself
                or score < 0):
            raise ValueError(f"Risk score for {drug_id} must be non-negative: {score}")

    @staticmethod
    def _check_probability(drug_id: str, prob: Any) -> None:
        """Raise ValueError unless ``prob`` is a number in [0, 1]."""
        if not isinstance(prob, PredictionSet._NUMERIC_TYPES) or not 0 <= prob <= 1:
            raise ValueError(f"Probability for {drug_id} must be in [0,1]: {prob}")

    def __post_init__(self):
        """Validate prediction set after initialization.

        Raises:
            TypeError: if candidates is not a list.
            ValueError: on an invalid risk score, probability, threshold or alpha.
        """
        if not isinstance(self.candidates, list):
            raise TypeError("candidates must be a list")

        # Remove duplicates while preserving order
        self.candidates = list(dict.fromkeys(self.candidates))

        for drug_id, score in self.risk_scores.items():
            self._check_risk_score(drug_id, score)

        for drug_id, prob in self.probabilities.items():
            self._check_probability(drug_id, prob)

        # Validate threshold and alpha
        if self.threshold is not None and self.threshold < 0:
            raise ValueError("Threshold must be non-negative")

        if self.alpha is not None and not 0 <= self.alpha <= 1:
            raise ValueError("Alpha must be in [0,1]")

    def add_candidate(
        self,
        drug_id: str,
        risk_score: Optional[float] = None,
        probability: Optional[float] = None
    ) -> None:
        """Add a new candidate to the prediction set.

        The values are validated with the same rules as in ``__post_init__``
        before anything is modified, so a rejected call leaves the set unchanged.

        Raises:
            ValueError: if risk_score or probability is invalid.
        """
        if risk_score is not None:
            self._check_risk_score(drug_id, risk_score)
        if probability is not None:
            self._check_probability(drug_id, probability)

        if drug_id not in self.candidates:
            self.candidates.append(drug_id)

        if risk_score is not None:
            self.risk_scores[drug_id] = risk_score

        if probability is not None:
            self.probabilities[drug_id] = probability

    def remove_candidate(self, drug_id: str) -> None:
        """Remove a candidate (and its score/probability) from the prediction set."""
        if drug_id in self.candidates:
            self.candidates.remove(drug_id)

        self.risk_scores.pop(drug_id, None)
        self.probabilities.pop(drug_id, None)

    def filter_by_threshold(
        self,
        threshold: float,
        *,
        missing_risk_score: float = 0.0
    ) -> 'PredictionSet':
        """
        Create a new prediction set filtering candidates by risk threshold.

        A candidate is kept when its risk score is ``<= threshold``. Candidates
        without a risk score are compared using ``missing_risk_score``. The
        default of 0.0 keeps them; pass ``float('inf')`` to drop them. No risk
        score is invented for them in the result.

        Args:
            threshold: Risk threshold for filtering
            missing_risk_score: Score assumed for candidates lacking one

        Returns:
            New PredictionSet with filtered candidates
        """
        filtered_candidates = []
        filtered_scores = {}
        filtered_probs = {}

        for drug_id in self.candidates:
            score = self.risk_scores.get(drug_id)
            effective = missing_risk_score if score is None else score
            if not effective <= threshold:  # NaN compares False -> dropped
                continue
            filtered_candidates.append(drug_id)
            if score is not None:
                filtered_scores[drug_id] = score
            if drug_id in self.probabilities:
                filtered_probs[drug_id] = self.probabilities[drug_id]

        return PredictionSet(
            candidates=filtered_candidates,
            risk_scores=filtered_scores,
            probabilities=filtered_probs,
            clinical_context=self.clinical_context,
            threshold=threshold,
            alpha=self.alpha,
            timestamp=self.timestamp,
            model_info=self.model_info.copy(),
            metadata=self.metadata.copy()
        )

    def get_sorted_candidates(
        self,
        sort_by: str = "risk_score",
        ascending: bool = True
    ) -> List[str]:
        """
        Get candidates sorted by specified criterion.

        Unscored candidates sort as infinite risk / zero probability.

        Args:
            sort_by: Sorting criterion ("risk_score", "probability", "drug_id")
            ascending: Whether to sort in ascending order

        Returns:
            List of sorted drug identifiers

        Raises:
            ValueError: for an unknown sort criterion
        """
        if sort_by == "risk_score":
            return sorted(
                self.candidates,
                key=lambda x: self.risk_scores.get(x, float('inf')),
                reverse=not ascending
            )
        elif sort_by == "probability":
            return sorted(
                self.candidates,
                key=lambda x: self.probabilities.get(x, 0.0),
                reverse=not ascending
            )
        elif sort_by == "drug_id":
            return sorted(self.candidates, reverse=not ascending)
        else:
            raise ValueError(f"Unknown sort criterion: {sort_by}")

    def get_top_candidates(self, n: int, sort_by: str = "probability") -> List[str]:
        """
        Get top N candidates based on specified criterion (descending order).

        Args:
            n: Number of candidates to return
            sort_by: Sorting criterion

        Returns:
            List of top N drug identifiers
        """
        sorted_candidates = self.get_sorted_candidates(sort_by, ascending=False)
        return sorted_candidates[:n]

    def get_risk_summary(self) -> Dict[str, Any]:
        """
        Get summary statistics about risks of the candidates in the set.

        Only scores belonging to current candidates are summarised. The
        returned dict always has the same keys; statistics are None when no
        candidate has a risk score.

        Returns:
            Dictionary with risk statistics
        """
        risk_values = [self.risk_scores[d] for d in self.candidates if d in self.risk_scores]

        if not risk_values:
            return {
                'n_candidates': len(self.candidates),
                'min_risk': None,
                'max_risk': None,
                'mean_risk': None,
                'std_risk': None,
                'median_risk': None
            }

        return {
            'n_candidates': len(self.candidates),
            'min_risk': float(np.min(risk_values)),
            'max_risk': float(np.max(risk_values)),
            'mean_risk': float(np.mean(risk_values)),
            'std_risk': float(np.std(risk_values)),
            'median_risk': float(np.median(risk_values))
        }

    def check_contraindications(self) -> List[str]:
        """
        Check for contraindicated drugs based on clinical context.

        Returns:
            List of contraindicated drug identifiers (empty without context)
        """
        if self.clinical_context is None:
            return []

        return [drug_id for drug_id in self.candidates
                if self.clinical_context.is_contraindicated(drug_id)]

    def remove_contraindications(self) -> 'PredictionSet':
        """
        Create a new prediction set with contraindicated drugs removed.

        The removed identifiers are recorded under
        ``metadata['removed_contraindications']`` of the new set.

        Returns:
            New PredictionSet without contraindicated drugs
        """
        contraindicated = self.check_contraindications()
        excluded = set(contraindicated)

        safe_candidates = [drug for drug in self.candidates if drug not in excluded]
        safe_scores = {drug: score for drug, score in self.risk_scores.items()
                       if drug not in excluded}
        safe_probs = {drug: prob for drug, prob in self.probabilities.items()
                      if drug not in excluded}

        new_metadata = self.metadata.copy()
        new_metadata['removed_contraindications'] = contraindicated

        return PredictionSet(
            candidates=safe_candidates,
            risk_scores=safe_scores,
            probabilities=safe_probs,
            clinical_context=self.clinical_context,
            threshold=self.threshold,
            alpha=self.alpha,
            timestamp=self.timestamp,
            model_info=self.model_info.copy(),
            metadata=new_metadata
        )

    def is_empty(self) -> bool:
        """Check if the prediction set is empty."""
        return len(self.candidates) == 0

    def size(self) -> int:
        """Get the size of the prediction set."""
        return len(self.candidates)

    def contains(self, drug_id: str) -> bool:
        """Check if a drug is in the prediction set."""
        return drug_id in self.candidates

    def get_coverage_info(self, true_drug: str) -> Dict[str, Any]:
        """
        Get coverage information for a true drug label.

        Args:
            true_drug: The true drug identifier

        Returns:
            Dictionary with coverage information
        """
        is_covered = self.contains(true_drug)

        coverage_info = {
            'is_covered': is_covered,
            'true_drug': true_drug,
            'set_size': self.size(),
            'threshold': self.threshold,
            'alpha': self.alpha
        }

        if is_covered and true_drug in self.risk_scores:
            coverage_info['true_drug_risk_score'] = self.risk_scores[true_drug]

        if is_covered and true_drug in self.probabilities:
            coverage_info['true_drug_probability'] = self.probabilities[true_drug]

        return coverage_info

    def to_dict(self) -> Dict[str, Any]:
        """Convert prediction set to dictionary format.

        Containers are shallow-copied so mutating the result does not alter
        this set. The clinical context is not serialised.
        """
        return {
            'candidates': list(self.candidates),
            'risk_scores': dict(self.risk_scores),
            'probabilities': dict(self.probabilities),
            'threshold': self.threshold,
            'alpha': self.alpha,
            'timestamp': self.timestamp,
            'model_info': dict(self.model_info),
            'metadata': dict(self.metadata),
            'risk_summary': self.get_risk_summary(),
            'size': self.size()
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any], clinical_context: Optional['ClinicalContext'] = None) -> 'PredictionSet':
        """Create PredictionSet from dictionary format.

        Mapping values are shallow-copied so later mutations of the new set
        (e.g. ``add_candidate``) do not write back into ``data``.
        """
        return cls(
            candidates=data.get('candidates', []),
            risk_scores=dict(data.get('risk_scores', {})),
            probabilities=dict(data.get('probabilities', {})),
            clinical_context=clinical_context,
            threshold=data.get('threshold'),
            alpha=data.get('alpha'),
            timestamp=data.get('timestamp'),
            model_info=dict(data.get('model_info', {})),
            metadata=dict(data.get('metadata', {}))
        )

    def __len__(self) -> int:
        return len(self.candidates)

    def __iter__(self):
        return iter(self.candidates)

    def __contains__(self, drug_id: str) -> bool:
        return drug_id in self.candidates

    def __str__(self) -> str:
        risk_summary = self.get_risk_summary()
        parts = [f"PredictionSet({self.size()} candidates)"]

        if self.threshold is not None:
            parts.append(f"threshold={self.threshold:.3f}")

        if risk_summary['mean_risk'] is not None:
            parts.append(f"mean_risk={risk_summary['mean_risk']:.3f}")

        return " | ".join(parts)

    def __repr__(self) -> str:
        return (f"PredictionSet(candidates={len(self.candidates)}, "
                f"threshold={self.threshold}, alpha={self.alpha})")
