"""
CNCRC Conformal Quantile Calibration

This module implements the conformal calibration logic to determine the quantile 
threshold q based on non-conformity scores from a calibration set. This threshold 
is critical for controlling the risk level in the CNCRC framework.

Theoretical Background:
The conformal prediction framework provides distribution-free finite-sample 
guarantees. Given calibration scores s(x_i, y_i) and desired risk level α, 
the quantile q is computed using:

q = Quantile(scores, ceil((n+1)(1-α))/n)

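i.e. q is the ceil((n+1)(1-α))-th smallest calibration score. When that rank
does not exceed n, and the calibration and test points are exchangeable, this
ensures the marginal coverage guarantee: P(s(X,Y) ≤ q) ≥ 1-α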
"""
import numpy as np
from typing import Union, List, Tuple, Optional, Dict, Any
import logging
import warnings

logger = logging.getLogger(__name__)


def calibrate_quantile(
    scores: np.ndarray,
    alpha: float
) -> float:
    """
    Calculate the conformal quantile threshold for given non-conformity scores.

    This implements the split-conformal calibration rule
        q = Quantile(scores, ceil((n+1)(1-α))/n),
    i.e. q is the k-th smallest calibration score with k = ceil((n+1)(1-α)).
    If calibration and test points are exchangeable, this provides the
    distribution-free finite-sample guarantee P(s(X,Y) ≤ q) ≥ 1-α.

    When n is too small for the requested alpha (k > n, i.e. n < 1/α - 1)
    the exact conformal threshold would be +inf. For backward compatibility
    the largest calibration score is returned instead and a warning is
    emitted, because the coverage guarantee no longer holds in that case.

    Args:
        scores: Array-like of non-conformity scores s(x_i, y_i) from the
            calibration set. Converted to float64 and flattened to 1D.
        alpha: Desired miscoverage risk level in (0, 1)
            (e.g., 0.1 for 90% coverage).

    Returns:
        Quantile threshold q for prediction set construction.

    Raises:
        ValueError: If alpha is not in (0,1), or scores is empty or contains
            NaN / infinite values.

    Example:
        >>> scores = np.arange(1, 11)
        >>> calibrate_quantile(scores, alpha=0.2)  # k = ceil(11 * 0.8) = 9
        9.0

    References:
        - Vovk et al. "Algorithmic Learning in a Random World" (2005)
        - Shafer & Vovk "A Tutorial on Conformal Prediction" (2008)
        - Lei et al. "Distribution-Free Predictive Inference For Regression" (2018)
    """
    # Always convert: the checks below need a numeric float array, and
    # ravel() makes n count every score even for row/column-vector input.
    scores = np.asarray(scores, dtype=np.float64).ravel()

    if scores.size == 0:
        raise ValueError("Scores array cannot be empty")

    if not 0 < alpha < 1:
        raise ValueError(f"Alpha must be in (0,1), got {alpha}")

    if np.isnan(scores).any():
        raise ValueError("Scores array contains NaN values")

    if np.isinf(scores).any():
        raise ValueError("Scores array contains infinite values")

    n = scores.size

    # 1-based rank of the conformal order statistic; always >= 1 because
    # (n + 1) * (1 - alpha) > 0 for alpha in (0, 1).
    k = int(np.ceil((n + 1) * (1 - alpha)))

    if k > n:
        warnings.warn(
            f"Calibration set of size n={n} is too small for alpha={alpha} "
            f"(need n >= 1/alpha - 1); returning the maximum score, so the "
            f"1-alpha coverage guarantee does not hold.",
            stacklevel=2,
        )
        k = n

    # Equivalent to the former min(1.0, ceil((n+1)(1-alpha)) / n).
    quantile_level = k / n

    # Select the k-th smallest score directly. np.quantile(..., k/n,
    # method='higher') uses index ceil(k*(n-1)/n) == k (0-based) for k < n,
    # i.e. the (k+1)-th smallest score -- one rank more conservative than
    # the conformal formula.
    q = np.partition(scores, k - 1)[k - 1]

    logging.getLogger(__name__).info(
        "Calibration: n=%d, alpha=%.3f, quantile_level=%.3f, q=%.6f",
        n, alpha, quantile_level, q,
    )

    return float(q)


def calibrate_quantile_detailed(
    scores: np.ndarray,
    alpha: float
) -> Dict[str, Any]:
    """
    Extended calibration function that returns detailed information.

    The threshold is the k-th smallest calibration score with
    k = ceil((n+1)(1-alpha)), matching the conformal formula
    Quantile(scores, ceil((n+1)(1-alpha))/n). If k > n the largest score is
    used and a warning is emitted, since the finite-sample coverage
    guarantee does not hold for such a small calibration set.

    Args:
        scores: Array-like of non-conformity scores from the calibration set
            (converted to float64 and flattened to 1D).
        alpha: Desired miscoverage risk level in (0, 1).

    Returns:
        Dictionary containing:
        - 'quantile': The calibrated threshold q
        - 'quantile_level': The empirical quantile level used, k/n (at most 1)
        - 'n_calibration': Number of calibration samples
        - 'alpha': Input risk level
        - 'theoretical_coverage': Expected coverage (1-alpha)
        - 'scores_stats': min/max/mean/std/median/q25/q75 of the scores

    Raises:
        ValueError: If alpha is not in (0,1), or scores is empty or contains
            NaN / infinite values.
    """
    scores = np.asarray(scores, dtype=np.float64).ravel()

    if scores.size == 0:
        raise ValueError("Scores array cannot be empty")

    if not 0 < alpha < 1:
        raise ValueError(f"Alpha must be in (0,1), got {alpha}")

    # Without these checks NaN/inf would silently propagate into the
    # quantile and every summary statistic below.
    if np.isnan(scores).any():
        raise ValueError("Scores array contains NaN values")

    if np.isinf(scores).any():
        raise ValueError("Scores array contains infinite values")

    n = scores.size

    # 1-based rank of the conformal order statistic.
    k = int(np.ceil((n + 1) * (1 - alpha)))
    if k > n:
        warnings.warn(
            f"Calibration set of size n={n} is too small for alpha={alpha} "
            f"(need n >= 1/alpha - 1); returning the maximum score, so the "
            f"1-alpha coverage guarantee does not hold.",
            stacklevel=2,
        )
        k = n
    quantile_level = k / n

    # k-th smallest score; np.quantile(..., method='higher') at level k/n
    # would return the (k+1)-th smallest score whenever k < n.
    q = np.partition(scores, k - 1)[k - 1]

    scores_stats = {
        'min': float(np.min(scores)),
        'max': float(np.max(scores)),
        'mean': float(np.mean(scores)),
        'std': float(np.std(scores)),
        'median': float(np.median(scores)),
        'q25': float(np.quantile(scores, 0.25)),
        'q75': float(np.quantile(scores, 0.75))
    }

    return {
        'quantile': float(q),
        'quantile_level': float(quantile_level),
        'n_calibration': int(n),
        'alpha': float(alpha),
        'theoretical_coverage': float(1 - alpha),
        'scores_stats': scores_stats
    }


def validate_calibration_coverage(
    calibration_scores: np.ndarray,
    test_scores: np.ndarray,
    alpha: float,
    quantile: Optional[float] = None
) -> Dict[str, float]:
    """
    Validate the empirical coverage of conformal calibration on test data.

    Args:
        calibration_scores: Scores used for calibration (only read when
            `quantile` is None).
        test_scores: Independent test scores for validation (flattened to 1D).
        alpha: Risk level used in calibration, in (0, 1).
        quantile: Pre-computed quantile (if None, computed from
            calibration_scores via calibrate_quantile).

    Returns:
        Dictionary with coverage statistics:
        - 'empirical_coverage': Fraction of test scores <= quantile
        - 'theoretical_coverage': Expected coverage (1-alpha)
        - 'coverage_gap': empirical_coverage - theoretical_coverage
        - 'n_test': Number of test samples
        - 'n_covered': Number of test samples covered
        - 'quantile_used': Threshold the test scores were compared against

    Raises:
        ValueError: If alpha is not in (0,1), test_scores is empty, or
            (when quantile is None) calibration_scores is invalid.
    """
    if not 0 < alpha < 1:
        raise ValueError(f"Alpha must be in (0,1), got {alpha}")

    # Flatten so n_test and n_covered count the same elements for any shape.
    test_scores = np.asarray(test_scores).ravel()
    n_test = test_scores.size

    # Checked before calibrating so an empty test set fails fast.
    if n_test == 0:
        raise ValueError("Test scores array cannot be empty")

    if quantile is None:
        quantile = calibrate_quantile(calibration_scores, alpha)

    # A test point is covered when its score does not exceed the threshold;
    # NaN scores compare False and therefore count as not covered.
    n_covered = int(np.count_nonzero(test_scores <= quantile))
    empirical_coverage = n_covered / n_test
    theoretical_coverage = 1 - alpha
    coverage_gap = empirical_coverage - theoretical_coverage

    return {
        'empirical_coverage': float(empirical_coverage),
        'theoretical_coverage': float(theoretical_coverage),
        'coverage_gap': float(coverage_gap),
        'n_test': int(n_test),
        'n_covered': int(n_covered),
        'quantile_used': float(quantile)
    }


def multi_alpha_calibration(
    scores: np.ndarray,
    alphas: List[float]
) -> Dict[float, float]:
    """
    Calibrate a conformal threshold for each of several risk levels.

    Handy for sweeping alpha to build calibration curves. Alphas outside the
    open interval (0, 1) are skipped with a warning rather than raising.

    Args:
        scores: Non-conformity scores from the calibration set.
        alphas: Risk levels to calibrate for, processed in order.

    Returns:
        Mapping from each valid alpha to its calibrated quantile.

    Raises:
        ValueError: If scores is empty.
    """
    calibration_scores = np.asarray(scores, dtype=np.float64)

    if not len(calibration_scores):
        raise ValueError("Scores array cannot be empty")

    quantile_by_alpha: Dict[float, float] = {}
    for level in alphas:
        if 0 < level < 1:
            quantile_by_alpha[level] = calibrate_quantile(calibration_scores, level)
        else:
            # Also reached for NaN, since every comparison with NaN is False.
            warnings.warn(f"Skipping invalid alpha: {level}")

    return quantile_by_alpha


def adaptive_calibration(
    scores: np.ndarray,
    alpha: float,
    target_coverage: Optional[float] = None,
    max_iterations: int = 10,
    tolerance: float = 0.01,
    step_size: float = 0.5
) -> Dict[str, Any]:
    """
    Adaptive calibration that adjusts alpha to achieve target empirical coverage.

    Each iteration computes the conformal quantile with calibrate_quantile,
    measures the fraction of calibration scores at or below it, and nudges
    alpha in proportion to the coverage error. Over-coverage raises alpha,
    which lowers the threshold; under-coverage lowers alpha. Alpha is kept
    within [0.001, 0.999].

    This is useful when you want a specific coverage rate on the calibration
    set itself, which might differ from the theoretical guarantee. Note that
    coverage on the calibration set only takes values that are multiples of
    1/n, so a tolerance much smaller than 1/n may never be met.

    Args:
        scores: Non-conformity scores from calibration set (flattened to 1D).
        alpha: Initial risk level.
        target_coverage: Desired empirical coverage (if None, uses 1-alpha).
        max_iterations: Maximum adjustment iterations.
        tolerance: Convergence tolerance for coverage.
        step_size: Multiplier applied to the coverage error when updating
            alpha (default 0.5).

    Returns:
        Dictionary with keys:
        - 'final_alpha': alpha that produced 'final_quantile' (the initial
          alpha if no iteration ran)
        - 'final_quantile', 'final_coverage': results of the last iteration
          (NaN if no iteration ran)
        - 'target_coverage', 'converged', 'iterations'
        - 'history': per-iteration dicts (iteration, alpha, quantile,
          empirical_coverage, coverage_error)

    Raises:
        ValueError: If scores is empty or target_coverage is not in (0,1).
    """
    scores = np.asarray(scores, dtype=np.float64).ravel()
    n = scores.size

    if n == 0:
        raise ValueError("Scores array cannot be empty")

    if target_coverage is None:
        target_coverage = 1 - alpha

    if not 0 < target_coverage < 1:
        raise ValueError(f"Target coverage must be in (0,1), got {target_coverage}")

    current_alpha = alpha
    history: List[Dict[str, Any]] = []

    for iteration in range(max_iterations):
        q = calibrate_quantile(scores, current_alpha)

        # Empirical coverage on the calibration set itself.
        empirical_coverage = float(np.mean(scores <= q))
        coverage_error = empirical_coverage - target_coverage

        history.append({
            'iteration': iteration,
            'alpha': current_alpha,
            'quantile': q,
            'empirical_coverage': empirical_coverage,
            'coverage_error': coverage_error
        })

        if abs(coverage_error) <= tolerance:
            break

        # Too much coverage (error > 0) needs a LARGER alpha (lower
        # threshold); too little needs a smaller one.
        current_alpha = float(
            np.clip(current_alpha + step_size * coverage_error, 0.001, 0.999)
        )

    final_result = history[-1] if history else {}

    return {
        # Report the alpha that produced the final quantile, not the next
        # (unused) update computed after the last iteration.
        'final_alpha': final_result.get('alpha', current_alpha),
        'final_quantile': final_result.get('quantile', np.nan),
        'final_coverage': final_result.get('empirical_coverage', np.nan),
        'target_coverage': target_coverage,
        'converged': abs(final_result.get('coverage_error', float('inf'))) <= tolerance,
        'iterations': len(history),
        'history': history
    }


def calibration_diagnostics(
    scores: np.ndarray,
    alpha: float,
    n_bootstrap: int = 1000,
    random_state: Optional[int] = 42
) -> Dict[str, Any]:
    """
    Comprehensive diagnostics for calibration quality.

    Computes the conformal quantile, then resamples the scores with
    replacement `n_bootstrap` times and recalibrates on each resample to
    estimate the variability of the threshold. Every resample goes through
    calibrate_quantile, so its INFO log line is emitted n_bootstrap + 1 times.

    Args:
        scores: Non-conformity scores (flattened to 1D).
        alpha: Risk level.
        n_bootstrap: Number of bootstrap samples for confidence intervals (>= 1).
        random_state: Seed for the legacy np.random.RandomState used for
            resampling. The default 42 keeps results reproducible; None
            seeds from the OS.

    Returns:
        Dictionary with:
        - 'quantile': calibrated threshold on the full score set
        - 'quantile_ci_lower' / 'quantile_ci_upper': 2.5% / 97.5% percentiles
          of the bootstrap quantiles
        - 'quantile_std': standard deviation of the bootstrap quantiles
        - 'n_calibration', 'alpha', 'bootstrap_samples'
        - 'effective_sample_size': (n+1)(1-alpha), the unrounded conformal rank
        - 'finite_sample_correction': ceil((n+1)(1-alpha)) / n, the
          unclipped quantile level (can exceed 1 for small n)

    Raises:
        ValueError: If n_bootstrap < 1, or if calibration rejects the
            scores / alpha.
    """
    # Validate first: an empty bootstrap set would crash in np.quantile below.
    if n_bootstrap < 1:
        raise ValueError(f"n_bootstrap must be a positive integer, got {n_bootstrap}")

    scores = np.asarray(scores, dtype=np.float64).ravel()
    n = scores.size

    q = calibrate_quantile(scores, alpha)

    # Legacy RandomState with the same draw order keeps bootstrap results
    # identical to previous versions for a given seed.
    rng = np.random.RandomState(random_state)
    bootstrap_quantiles = np.array([
        calibrate_quantile(rng.choice(scores, size=n, replace=True), alpha)
        for _ in range(n_bootstrap)
    ])

    # Percentile bootstrap confidence interval.
    ci_lower = np.quantile(bootstrap_quantiles, 0.025)
    ci_upper = np.quantile(bootstrap_quantiles, 0.975)

    # Unrounded conformal rank (kept under its historical key name).
    effective_n = (n + 1) * (1 - alpha)

    return {
        'quantile': float(q),
        'quantile_ci_lower': float(ci_lower),
        'quantile_ci_upper': float(ci_upper),
        'quantile_std': float(np.std(bootstrap_quantiles)),
        'n_calibration': int(n),
        'alpha': float(alpha),
        'effective_sample_size': float(effective_n),
        'finite_sample_correction': float(np.ceil(effective_n) / n),
        'bootstrap_samples': int(n_bootstrap)
    }