"""
Conformal Non-Coverage Risk Control (CNCRC) - Prediction Set Construction

This module implements the core prediction set construction algorithm for CNCRC,
building conformal prediction sets based on calibrated risk thresholds.

Key Components:
- build_prediction_set: Constructs prediction sets C(x) = {y: s(x,y) ≤ q}
- build_prediction_set_detailed: Enhanced version with risk scores and metadata
- build_prediction_set_batch: Batch processing for multiple inputs
- validate_prediction_set: Empirical validation of set properties

Theoretical Foundation:
The prediction set C(x) contains all labels y whose risk-weighted non-conformity
score s(x,y) = max_{j≠y} P(y_j|x) · Cost(y,y_j) is at most (≤) the calibrated threshold q.

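By split-conformal theory, if q is calibrated as the ⌈(n+1)(1-α)⌉/n empirical quantile
of n calibration scores that are exchangeable with the test point, then the true label
is contained in the prediction set with probability at least 1-α (marginal coverage):
P(y* ∈ C(X)) ≥ 1-α.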
"""

import numpy as np
import warnings
from typing import List, Union, Dict, Tuple, Optional, Any
from dataclasses import dataclass, field

from .risk_weighted_score import calculate_risk_weighted_score, calculate_risk_weighted_scores_batch
from .data_structures import PredictionSet, ClinicalContext, CostMatrix


def build_prediction_set(
    probabilities: Union[np.ndarray, List[float]], 
    cost_matrix: Union[np.ndarray, CostMatrix],
    q_threshold: float
) -> List[int]:
    """
    Construct a CNCRC prediction set C(x) = {y: s(x,y) ≤ q}.

    For each candidate label y (in ascending index order), computes the
    risk-weighted non-conformity score s(x,y) with
    ``calculate_risk_weighted_score``. It includes y in the prediction set when
    s(x,y) ≤ q_threshold.

    Args:
        probabilities: Model output probabilities P(y_j|x), shape (n_classes,).
            Non-ndarray inputs are converted to a float64 array.
        cost_matrix: Cost matrix Cost(y,y_j), shape (n_classes, n_classes),
            or a CostMatrix object (its ``.matrix`` attribute is used).
        q_threshold: Calibrated quantile threshold from conformal calibration.
            Python ints/floats are accepted, and so are NumPy integer/floating
            scalars (e.g. the np.float32 that ``np.quantile`` returns for
            float32 scores).

    Returns:
        List of candidate indices (ints, in ascending order) that form the
        prediction set.

    Raises:
        TypeError: If q_threshold is not a numeric scalar.
        ValueError: If q_threshold is NaN or negative, if probabilities is
            not 1D, or if the cost matrix shape is not (n_classes, n_classes).

    Example:
        >>> import numpy as np
        >>> probs = np.array([0.1, 0.6, 0.2, 0.1])  # 4 drugs
        >>> costs = np.array([[0.0, 0.5, 0.8, 0.3],
        ...                   [0.5, 0.0, 0.6, 0.4],
        ...                   [0.8, 0.6, 0.0, 0.7],
        ...                   [0.3, 0.4, 0.7, 0.0]])
        >>> q = 0.25  # Calibrated threshold
        >>> prediction_set = build_prediction_set(probs, costs, q)
        >>> # If the score follows the module-level formula, this is [1, 3].

    Note:
        This is the core CNCRC algorithm. Under exchangeability, and with a
        properly calibrated q, the set carries a distribution-free marginal
        coverage guarantee.
        If scoring a candidate raises, a RuntimeWarning is emitted and that
        candidate is left out of the set. This can only shrink the set, so
        callers that rely on the coverage guarantee may want to escalate
        such warnings to errors.
    """
    # Validate the scalar threshold first; it does not depend on the arrays.
    # np.float32 / np.int64 do not subclass float / int, so they are listed
    # explicitly instead of being rejected as "non-numeric".
    if not isinstance(q_threshold, (int, float, np.integer, np.floating)):
        raise TypeError(f"q_threshold must be numeric, got {type(q_threshold)}")

    # NaN is the only value unequal to itself. A NaN threshold would slip past
    # the `< 0` check and silently yield an empty set (`score <= NaN` is False).
    if q_threshold != q_threshold:
        raise ValueError(f"q_threshold must not be NaN, got {q_threshold}")

    if q_threshold < 0:
        raise ValueError(f"q_threshold must be non-negative, got {q_threshold}")

    if not isinstance(probabilities, np.ndarray):
        probabilities = np.asarray(probabilities, dtype=np.float64)

    if probabilities.ndim != 1:
        raise ValueError(f"probabilities must be 1D, got shape {probabilities.shape}")

    n_classes = len(probabilities)

    # Accept either a raw array or a CostMatrix wrapper.
    if isinstance(cost_matrix, CostMatrix):
        cost_array = cost_matrix.matrix
    else:
        cost_array = np.asarray(cost_matrix, dtype=np.float64)

    if cost_array.shape != (n_classes, n_classes):
        raise ValueError(
            f"cost_matrix shape {cost_array.shape} incompatible with "
            f"probabilities length {n_classes}"
        )

    # Core algorithm: test each candidate y against the threshold.
    prediction_set = []

    for y_candidate in range(n_classes):
        try:
            score = calculate_risk_weighted_score(
                probabilities=probabilities,
                cost_matrix=cost_array,
                y_true=y_candidate
            )
        except Exception as e:
            # Best-effort: report and exclude the candidate rather than abort.
            warnings.warn(
                f"Failed to calculate score for candidate {y_candidate}: {e}",
                RuntimeWarning
            )
            continue

        if score <= q_threshold:
            prediction_set.append(y_candidate)

    return prediction_set


def build_prediction_set_detailed(
    probabilities: Union[np.ndarray, List[float]],
    cost_matrix: Union[np.ndarray, CostMatrix],
    q_threshold: float,
    clinical_context: Optional[ClinicalContext] = None,
    alpha: Optional[float] = None,
    drug_mapping: Optional[Dict[int, str]] = None,
    include_all_scores: bool = False
) -> PredictionSet:
    """
    Build a detailed CNCRC prediction set with metadata and risk scores.

    Set membership (and all input validation) is delegated to
    ``build_prediction_set``. This wrapper then re-scores candidates to attach
    per-candidate risk scores and probabilities, and returns everything as a
    ``PredictionSet``.

    Args:
        probabilities: Model output probabilities P(y_j|x), shape (n_classes,)
        cost_matrix: Cost matrix or CostMatrix object
        q_threshold: Calibrated quantile threshold
        clinical_context: Patient clinical context (optional, passed through)
        alpha: Miscoverage level used in calibration (optional, passed through)
        drug_mapping: Mapping from indices to drug identifiers (optional).
            Indices missing from the mapping, or all indices when the mapping
            is None/empty, are labelled with ``str(index)``.
        include_all_scores: If True, score every candidate and expose those
            scores in ``metadata["all_risk_scores"]``.

    Returns:
        PredictionSet whose ``candidates`` are drug ids (strings by default).
        Its ``risk_scores``/``probabilities`` dicts are restricted to those
        candidates. In ``metadata``, ``all_risk_scores`` is keyed by drug id,
        while ``all_probabilities`` is keyed by integer index. Both are None
        unless include_all_scores is True.

    Raises:
        TypeError / ValueError: Propagated from ``build_prediction_set`` for
            invalid inputs.

    Note:
        If re-scoring a candidate raises, a RuntimeWarning is emitted and that
        candidate has no entry in the score/probability dicts.

    Example:
        >>> prediction_set = build_prediction_set_detailed(
        ...     probabilities=probs,
        ...     cost_matrix=costs,
        ...     q_threshold=0.25,
        ...     alpha=0.1,
        ...     drug_mapping={0: "drug_A", 1: "drug_B", 2: "drug_C", 3: "drug_D"}
        ... )
        >>> print(f"Recommended drugs: {prediction_set.candidates}")
        >>> print(f"Risk scores: {prediction_set.risk_scores}")
    """
    # Core algorithm decides membership, so both entry points always agree.
    prediction_indices = build_prediction_set(probabilities, cost_matrix, q_threshold)

    if not isinstance(probabilities, np.ndarray):
        probabilities = np.asarray(probabilities, dtype=np.float64)

    if isinstance(cost_matrix, CostMatrix):
        cost_array = cost_matrix.matrix
    else:
        cost_array = np.asarray(cost_matrix, dtype=np.float64)

    def _label(idx):
        # Single source of truth for index -> drug id labelling.
        return drug_mapping.get(idx, str(idx)) if drug_mapping else str(idx)

    risk_scores = {}
    probabilities_dict = {}

    candidates_to_score = range(len(probabilities)) if include_all_scores else prediction_indices

    for y_candidate in candidates_to_score:
        try:
            score = calculate_risk_weighted_score(
                probabilities=probabilities,
                cost_matrix=cost_array,
                y_true=y_candidate
            )
            drug_id = _label(y_candidate)
            risk_scores[drug_id] = score
            probabilities_dict[drug_id] = float(probabilities[y_candidate])
        except Exception as e:
            warnings.warn(f"Failed to score candidate {y_candidate}: {e}", RuntimeWarning)
            continue

    candidates = [_label(idx) for idx in prediction_indices]
    # Set for O(1) membership; testing against the list was O(n) per key.
    candidate_ids = set(candidates)

    return PredictionSet(
        candidates=candidates,
        risk_scores={k: v for k, v in risk_scores.items() if k in candidate_ids},
        probabilities={k: v for k, v in probabilities_dict.items() if k in candidate_ids},
        clinical_context=clinical_context,
        threshold=q_threshold,
        alpha=alpha,
        metadata={
            "algorithm": "CNCRC",
            "n_candidates_evaluated": len(probabilities),
            "prediction_set_size": len(prediction_indices),
            "all_risk_scores": risk_scores if include_all_scores else None,
            "all_probabilities": dict(enumerate(probabilities)) if include_all_scores else None
        }
    )


def build_prediction_set_batch(
    probabilities_batch: np.ndarray,
    cost_matrix: Union[np.ndarray, CostMatrix],
    q_threshold: float,
    drug_mapping: Optional[Dict[int, str]] = None
) -> List[List[int]]:
    """
    Construct CNCRC prediction sets for multiple inputs in batch.

    Applies ``build_prediction_set`` to each row, using the same cost matrix
    and threshold for every row. Suitable for evaluation workflows.

    Args:
        probabilities_batch: Batch of probability vectors, shape (batch_size, n_classes)
        cost_matrix: Cost matrix, shape (n_classes, n_classes), or CostMatrix object
        q_threshold: Calibrated quantile threshold. It is validated per row by
            ``build_prediction_set``, so it is not checked for an empty batch.
        drug_mapping: Currently ignored by this function. The returned sets
            always contain integer indices, not drug names.

    Returns:
        List of prediction sets (lists of integer indices), one per row

    Raises:
        ValueError: If probabilities_batch is not 2D or the cost matrix shape
            is not (n_classes, n_classes). Errors raised by
            ``build_prediction_set`` for a row are propagated.

    Example:
        >>> batch_probs = np.array([[0.1, 0.6, 0.2, 0.1],
        ...                         [0.4, 0.1, 0.3, 0.2]])
        >>> prediction_sets = build_prediction_set_batch(batch_probs, costs, 0.25)
        >>> print(f"Batch prediction sets: {prediction_sets}")
    """
    if not isinstance(probabilities_batch, np.ndarray):
        probabilities_batch = np.asarray(probabilities_batch, dtype=np.float64)

    if probabilities_batch.ndim != 2:
        raise ValueError(f"probabilities_batch must be 2D, got shape {probabilities_batch.shape}")

    _, n_classes = probabilities_batch.shape

    if isinstance(cost_matrix, CostMatrix):
        cost_array = cost_matrix.matrix
    else:
        cost_array = np.asarray(cost_matrix, dtype=np.float64)

    if cost_array.shape != (n_classes, n_classes):
        # Report both shapes; the previous message omitted them entirely.
        raise ValueError(
            f"cost_matrix shape {cost_array.shape} incompatible with "
            f"probabilities_batch shape {probabilities_batch.shape}; expected "
            f"({n_classes}, {n_classes})"
        )

    # Iterating a 2D array yields its rows in order.
    return [
        build_prediction_set(
            probabilities=row,
            cost_matrix=cost_array,
            q_threshold=q_threshold
        )
        for row in probabilities_batch
    ]


def validate_prediction_set(
    prediction_set: List[int],
    probabilities: np.ndarray,
    cost_matrix: Union[np.ndarray, CostMatrix],
    q_threshold: float,
    tolerance: float = 1e-10
) -> Dict[str, Any]:
    """
    Validate that a prediction set satisfies CNCRC construction rules.

    Recomputes every candidate's risk-weighted score and checks three rules:
    every included candidate has score ≤ q_threshold + tolerance; every
    excluded candidate has score > q_threshold - tolerance; and no candidate
    appears more than once. Scores within ``tolerance`` of the threshold are
    accepted on either side.

    Args:
        prediction_set: Candidate indices in the prediction set
        probabilities: Model output probabilities, shape (n_classes,)
        cost_matrix: Cost matrix or CostMatrix object
        q_threshold: Threshold used for construction
        tolerance: Numerical tolerance for score comparisons

    Returns:
        Dictionary with validation results and statistics. On success it
        contains "is_valid", "violations", "prediction_set_size",
        "total_candidates", "included_scores", "excluded_scores",
        "max_included_score", "min_excluded_score", "threshold" and
        "all_scores". If any score cannot be computed it instead contains
        "is_valid" (False), "error", "violations" (holding the error message)
        and "threshold".

    Example:
        >>> validation = validate_prediction_set(
        ...     prediction_set=[0, 1],
        ...     probabilities=probs,
        ...     cost_matrix=costs,
        ...     q_threshold=0.25
        ... )
        >>> print(f"Valid: {validation['is_valid']}")
    """
    if not isinstance(probabilities, np.ndarray):
        probabilities = np.asarray(probabilities, dtype=np.float64)

    n_classes = len(probabilities)

    if isinstance(cost_matrix, CostMatrix):
        cost_array = cost_matrix.matrix
    else:
        cost_array = np.asarray(cost_matrix, dtype=np.float64)

    # Score every candidate; any failure makes the set unverifiable.
    all_scores = {}
    for y_candidate in range(n_classes):
        try:
            all_scores[y_candidate] = calculate_risk_weighted_score(
                probabilities=probabilities,
                cost_matrix=cost_array,
                y_true=y_candidate
            )
        except Exception as e:
            error = f"Failed to calculate score for candidate {y_candidate}: {e}"
            # Keep "violations"/"threshold" so callers can index the result
            # the same way on both the success and the failure path.
            return {
                "is_valid": False,
                "error": error,
                "violations": [error],
                "threshold": q_threshold,
            }

    violations = []
    # Set for O(1) membership tests in the excluded-candidate scan below.
    included = set(prediction_set)

    if len(included) != len(prediction_set):
        violations.append("Prediction set contains duplicate candidates")

    # Included candidates must not exceed the threshold.
    for candidate in prediction_set:
        if candidate not in all_scores:
            violations.append(f"Candidate {candidate} not in score calculations")
        elif all_scores[candidate] > q_threshold + tolerance:
            violations.append(
                f"Included candidate {candidate} has score {all_scores[candidate]:.6f} > "
                f"threshold {q_threshold:.6f}"
            )

    # Excluded candidates must be strictly above the threshold.
    excluded_candidates = [i for i in range(n_classes) if i not in included]
    for candidate in excluded_candidates:
        if all_scores[candidate] <= q_threshold - tolerance:
            violations.append(
                f"Excluded candidate {candidate} has score {all_scores[candidate]:.6f} ≤ "
                f"threshold {q_threshold:.6f}"
            )

    included_scores = [all_scores[c] for c in prediction_set if c in all_scores]
    excluded_scores = [all_scores[c] for c in excluded_candidates if c in all_scores]

    return {
        "is_valid": len(violations) == 0,
        "violations": violations,
        "prediction_set_size": len(prediction_set),
        "total_candidates": n_classes,
        "included_scores": included_scores,
        "excluded_scores": excluded_scores,
        "max_included_score": max(included_scores) if included_scores else None,
        "min_excluded_score": min(excluded_scores) if excluded_scores else None,
        "threshold": q_threshold,
        "all_scores": all_scores
    }


# Utility functions for analysis and debugging

def analyze_prediction_set_sensitivity(
    probabilities: np.ndarray,
    cost_matrix: Union[np.ndarray, CostMatrix],
    q_thresholds: List[float]
) -> Dict[str, Any]:
    """
    Sweep the threshold q and record how the prediction set responds.

    For each threshold, in the given order, builds the prediction set with
    ``build_prediction_set``. This helps show the trade-off between coverage
    and set size.

    Args:
        probabilities: Model output probabilities
        cost_matrix: Cost matrix
        q_thresholds: Threshold values to test

    Returns:
        Dictionary with keys:
        - "thresholds": the q_thresholds argument itself (not a copy)
        - "set_sizes": size of the prediction set at each threshold
        - "prediction_sets": the prediction set at each threshold
        - "coverage_efficiency": set size divided by the number of classes,
          i.e. the fraction of all candidates included (despite the name,
          this is not an empirical coverage rate)
    """
    total_classes = len(probabilities)

    sizes = []
    sets_per_threshold = []
    fractions = []

    for threshold in q_thresholds:
        members = build_prediction_set(probabilities, cost_matrix, threshold)
        size = len(members)
        sizes.append(size)
        sets_per_threshold.append(members)
        fractions.append(size / total_classes)

    return {
        "thresholds": q_thresholds,
        "set_sizes": sizes,
        "prediction_sets": sets_per_threshold,
        "coverage_efficiency": fractions,
    }



