"""
CNCRC Risk-Weighted Non-Conformity Score Calculation

This module implements the core risk-weighted non-conformity score function:
s(x,y) = max_{j≠y} P(y_j|x) · Cost(y,y_j)

This is the central calculation of the CNCRC framework that combines
model probabilities with drug interaction costs.
"""
import numpy as np
from typing import Union, List, Optional, Tuple, Dict, Any
import logging
from scipy import sparse

# Module-level logger; used below to report non-fatal input problems
# (e.g. probabilities that do not sum to 1) instead of raising.
logger = logging.getLogger(__name__)


def calculate_risk_weighted_score(
    probabilities: np.ndarray,
    cost_matrix: np.ndarray, 
    y_true: int
) -> float:
    """
    Calculate the risk-weighted non-conformity score for a single candidate.
    
    Implements: s(x,y) = max_{j≠y} P(y_j|x) · Cost(y,y_j)
    
    Args:
        probabilities: 1D array of P(y_j|x) for all j, shape (n_classes,).
            Integer arrays (e.g. one-hot vectors) are accepted.
        cost_matrix: 2D array of Cost(y, y_j), shape (n_classes, n_classes)
        y_true: Index of the true label/candidate drug, in [0, n_classes).
        
    Returns:
        Risk-weighted score (float). If there is no class other than
        ``y_true`` (n_classes == 1), the score is 0.0.

    Raises:
        ValueError: If ``probabilities`` is not 1D, ``cost_matrix`` has the
            wrong shape, ``y_true`` is out of range, or any probability is
            NaN or outside [0, 1].
        
    Example:
        >>> probs = np.array([0.1, 0.6, 0.2, 0.1])
        >>> costs = np.array([[0.0, 0.5, 0.8, 0.3],
        ...                   [0.5, 0.0, 0.6, 0.4],
        ...                   [0.8, 0.6, 0.0, 0.7],
        ...                   [0.3, 0.4, 0.7, 0.0]])
        >>> score = calculate_risk_weighted_score(probs, costs, y_true=1)
        >>> print(f"Risk score: {score:.3f}")
        Risk score: 0.120
    """
    probabilities = np.asarray(probabilities)
    cost_matrix = np.asarray(cost_matrix)

    # A 2D "probabilities" of shape (n, n) would otherwise pass the shape check
    # below and silently broadcast against the cost row.
    if probabilities.ndim != 1:
        raise ValueError(f"probabilities must be 1D, got shape {probabilities.shape}")
    n_classes = probabilities.shape[0]

    if cost_matrix.shape != (n_classes, n_classes):
        raise ValueError(f"Cost matrix shape {cost_matrix.shape} doesn't match "
                        f"probabilities length {n_classes}")

    if not 0 <= y_true < n_classes:
        raise ValueError(f"y_true={y_true} out of range [0, {n_classes-1}]")

    # Phrased as "not all in range" so that NaN, which fails every
    # comparison, is rejected as well.
    if not np.all((probabilities >= 0) & (probabilities <= 1)):
        raise ValueError("Probabilities must be in [0, 1]")

    if not np.allclose(np.sum(probabilities), 1.0, atol=1e-6):
        logger.warning(f"Probabilities don't sum to 1: {np.sum(probabilities):.6f}")

    # P(y_j|x) * Cost(y_true, y_j) for every j != y_true. Excluding y_true via a
    # boolean mask (rather than overwriting it with -inf) also works for
    # integer dtypes, which cannot store -inf.
    others = np.arange(n_classes) != y_true
    weighted_risks = probabilities[others] * cost_matrix[y_true, others]

    if weighted_risks.size == 0:
        logger.warning(f"No class other than y_true={y_true}; returning 0.0")
        return 0.0

    return float(np.max(weighted_risks))


def calculate_risk_weighted_scores_batch(
    probabilities: np.ndarray,
    cost_matrix: np.ndarray,
    y_candidates: Optional[List[int]] = None
) -> np.ndarray:
    """
    Calculate risk-weighted scores for multiple candidates efficiently.

    For every candidate i this computes
    s(x, y_i) = max_{j≠i} P(y_j|x) · Cost(y_i, y_j)
    in one vectorized pass over the full cost matrix.
    
    Args:
        probabilities: 1D array of P(y_j|x) for all j, shape (n_classes,)
        cost_matrix: 2D array of Cost(y, y_j), shape (n_classes, n_classes)  
        y_candidates: Candidate indices in [0, n_classes). If None, compute
            for all classes.
        
    Returns:
        Array of risk scores, shape (len(y_candidates),). A row with no
        off-diagonal entry (n_classes == 1) scores 0.0.

    Raises:
        ValueError: If ``probabilities`` is not 1D, ``cost_matrix`` has the
            wrong shape, or a candidate index is out of range.
        
    Example:
        >>> probs = np.array([0.1, 0.6, 0.2, 0.1])
        >>> costs = 0.5 * (1 - np.eye(4))  # uniform cost, zero on the diagonal
        >>> scores = calculate_risk_weighted_scores_batch(probs, costs)
        >>> print(f"Scores: {scores}")
        Scores: [0.3 0.1 0.3 0.3]
    """
    probabilities = np.asarray(probabilities)
    cost_matrix = np.asarray(cost_matrix)

    if probabilities.ndim != 1:
        raise ValueError(f"probabilities must be 1D, got shape {probabilities.shape}")
    n_classes = probabilities.shape[0]

    if cost_matrix.shape != (n_classes, n_classes):
        raise ValueError(f"Cost matrix shape {cost_matrix.shape} doesn't match "
                        f"probabilities length {n_classes}")

    if y_candidates is None:
        candidates = list(range(n_classes))
    else:
        # Reject out-of-range indices explicitly: NumPy would otherwise wrap
        # negative indices around and silently return another class's score.
        candidates = list(y_candidates)
        for c in candidates:
            if not 0 <= c < n_classes:
                raise ValueError(f"Candidate {c} out of range [0, {n_classes-1}]")

    # weighted_risks[i, j] = Cost(y_i, y_j) * P(y_j|x): the (1, n_classes)
    # probability row broadcasts over every row of the (n_classes, n_classes)
    # cost matrix. This is a fresh array, so the caller's inputs are untouched.
    weighted_risks = cost_matrix * probabilities[np.newaxis, :]
    # Integer inputs give an integer product, which cannot hold the -inf sentinel.
    if not np.issubdtype(weighted_risks.dtype, np.floating):
        weighted_risks = weighted_risks.astype(np.float64)

    # Exclude self-interactions (j == i) from the row-wise max.
    np.fill_diagonal(weighted_risks, -np.inf)
    all_scores = np.max(weighted_risks, axis=1)

    # A row whose max is still -inf has no usable off-diagonal entry
    # (e.g. n_classes == 1); report it as zero risk.
    all_scores = np.where(all_scores == -np.inf, 0.0, all_scores)

    return all_scores[candidates]


def calculate_risk_weighted_score_sparse(
    probabilities: np.ndarray,
    cost_matrix: sparse.spmatrix,
    y_true: int
) -> float:
    """
    Calculate risk-weighted score with sparse cost matrix for efficiency.

    Computes the same quantity as the dense version,
    s(x,y) = max_{j≠y} P(y_j|x) · Cost(y,y_j). Only row ``y_true`` of the
    cost matrix is densified, and entries not stored in the sparse matrix
    count as zero cost.
    
    Args:
        probabilities: 1D array of P(y_j|x) for all j
        cost_matrix: Sparse cost matrix of shape (n_classes, n_classes), in
            any scipy.sparse format. It is converted to CSR for row access.
        y_true: Index of the true label, in [0, n_classes).
        
    Returns:
        Risk-weighted score. 0.0 when there is no class other than ``y_true``.

    Raises:
        TypeError: If ``cost_matrix`` is not a scipy sparse matrix.
        ValueError: If ``probabilities`` is not 1D, the shapes disagree, or
            ``y_true`` is out of range.
    """
    if not sparse.issparse(cost_matrix):
        raise TypeError("cost_matrix must be a sparse matrix")

    probabilities = np.asarray(probabilities)
    if probabilities.ndim != 1:
        raise ValueError(f"probabilities must be 1D, got shape {probabilities.shape}")
    n_classes = probabilities.shape[0]

    if cost_matrix.shape != (n_classes, n_classes):
        raise ValueError(f"Cost matrix shape {cost_matrix.shape} doesn't match "
                        f"probabilities length {n_classes}")

    # Without this check a negative y_true would wrap around to another row.
    if not 0 <= y_true < n_classes:
        raise ValueError(f"y_true={y_true} out of range [0, {n_classes-1}]")

    # COO/DIA matrices do not support row indexing. tocsr() converts them and
    # returns a CSR matrix unchanged.
    costs_for_y = cost_matrix.tocsr()[y_true, :].toarray().ravel()

    # Mask out y_true instead of writing -inf, which integer dtypes cannot hold.
    others = np.arange(n_classes) != y_true
    weighted_risks = probabilities[others] * costs_for_y[others]

    if weighted_risks.size == 0:
        return 0.0

    return float(np.max(weighted_risks))


def get_top_risk_contributors(
    probabilities: np.ndarray,
    cost_matrix: np.ndarray,
    y_true: int,
    top_k: int = 3
) -> List[Tuple[int, float, float, float]]:
    """
    Get the top-k contributors to the risk score for interpretability.

    Each contributor is a class j != y_true together with the two factors
    of its weighted risk P(y_j|x) · Cost(y_true, y_j).
    
    Args:
        probabilities: 1D array of P(y_j|x) for all j
        cost_matrix: 2D cost matrix, shape (n_classes, n_classes)
        y_true: Index of the true label, in [0, n_classes).
        top_k: Number of top contributors to return. Must be non-negative.
            0 returns an empty list.
        
    Returns:
        List of tuples (class_idx, probability, cost, weighted_risk)
        sorted by weighted_risk in descending order. Ties are broken by
        ascending class index.

    Raises:
        ValueError: If shapes disagree, ``y_true`` is out of range, or
            ``top_k`` is negative.
    """
    probabilities = np.asarray(probabilities)
    cost_matrix = np.asarray(cost_matrix)
    n_classes = len(probabilities)

    if cost_matrix.shape != (n_classes, n_classes):
        raise ValueError(f"Cost matrix shape {cost_matrix.shape} doesn't match "
                        f"probabilities length {n_classes}")
    if not 0 <= y_true < n_classes:
        raise ValueError(f"y_true={y_true} out of range [0, {n_classes-1}]")
    # A negative top_k would turn the [:top_k] slice into "all but the last |k|".
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    costs_for_y = cost_matrix[y_true, :]
    # Work in float64 so integer inputs are handled and negation below cannot
    # overflow. A float32 product is upcast exactly.
    weighted = np.asarray(probabilities * costs_for_y, dtype=np.float64)

    # Every class except y_true is a candidate. Entries that are -inf (only
    # possible with -inf costs) are dropped, as before.
    candidates = np.flatnonzero((np.arange(n_classes) != y_true) & ~np.isneginf(weighted))

    # A stable sort on the negated risks gives descending order, with ties
    # kept in ascending class order.
    order = np.argsort(-weighted[candidates], kind="stable")[:top_k]

    return [
        (
            int(j),                      # class index
            float(probabilities[j]),     # probability
            float(costs_for_y[j]),       # cost
            float(weighted[j]),          # weighted risk
        )
        for j in candidates[order]
    ]


def validate_inputs(
    probabilities: np.ndarray,
    cost_matrix: np.ndarray,
    candidates: Optional[List[int]] = None
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
    """
    Validate and normalize inputs for risk score calculation.

    Non-array inputs are converted to float64 arrays. Probabilities that do
    not sum to 1 (within 1e-6) are rescaled to sum to 1, and a warning is
    logged.
    
    Args:
        probabilities: 1D probability array, shape (n_classes,)
        cost_matrix: Cost matrix, shape (n_classes, n_classes)
        candidates: Optional iterable of candidate indices in [0, n_classes).
            Defaults to all classes.
        
    Returns:
        Tuple of (probabilities, cost_matrix, candidates), where candidates
        is always a new list.

    Raises:
        ValueError: If probabilities are not 1D, contain NaN, lie outside
            [0, 1], or sum to 0; if the cost matrix has the wrong shape or
            contains NaN or negative values; or if a candidate is out of range.
    """
    # Convert to numpy arrays
    if not isinstance(probabilities, np.ndarray):
        probabilities = np.asarray(probabilities, dtype=np.float64)
    if not isinstance(cost_matrix, np.ndarray):
        cost_matrix = np.asarray(cost_matrix, dtype=np.float64)

    if probabilities.ndim != 1:
        raise ValueError(f"probabilities must be 1D, got shape {probabilities.shape}")
    n_classes = probabilities.shape[0]

    # Validate shapes
    if cost_matrix.shape != (n_classes, n_classes):
        raise ValueError(f"Cost matrix shape {cost_matrix.shape} doesn't match "
                        f"probabilities length {n_classes}")

    # Validate probability constraints. NaN passes both range comparisons, so
    # it needs its own check.
    if np.any(np.isnan(probabilities)):
        raise ValueError("Probabilities cannot be NaN")

    if np.any(probabilities < 0):
        raise ValueError("Probabilities cannot be negative")

    if np.any(probabilities > 1):
        raise ValueError("Probabilities cannot exceed 1")

    prob_sum = np.sum(probabilities)
    if not np.isclose(prob_sum, 1.0, atol=1e-6):
        # All-zero (or empty) input cannot be normalized; dividing would give NaN.
        if prob_sum == 0:
            raise ValueError("Probabilities sum to 0 and cannot be normalized")
        logger.warning(f"Probabilities sum to {prob_sum:.6f}, normalizing to 1.0")
        probabilities = probabilities / prob_sum

    # Validate cost matrix
    if np.any(np.isnan(cost_matrix)):
        raise ValueError("Cost matrix cannot contain NaN")

    if np.any(cost_matrix < 0):
        raise ValueError("Cost matrix cannot contain negative values")

    # Handle candidates
    if candidates is None:
        candidates = list(range(n_classes))
    else:
        # Materialize first. An iterator would otherwise be consumed by the
        # validation loop and returned empty.
        candidates = list(candidates)
        for c in candidates:
            if not 0 <= c < n_classes:
                raise ValueError(f"Candidate {c} out of range [0, {n_classes-1}]")

    return probabilities, cost_matrix, candidates




