"""
Utilities for generating consolidated UQ results tables.

This module provides functions to run comprehensive UQ analysis across
different output types (beam vs bayes), space metrics, evaluation metrics,
targets, and estimators.
"""

from typing import Callable, Dict, List, Optional, Tuple
from dataclasses import dataclass
import numpy as np
import pandas as pd

from structured_llmuq.utils.uq_metrics import auc_cont
from lm_polygraph.ue_metrics.pred_rej_area import PredictionRejectionArea
from lm_polygraph.ue_metrics.ue_metric import get_random_scores


# =============================================================================
# Data Classes
# =============================================================================

@dataclass
class UQResult:
    """One row of the consolidated UQ results table.

    Identifies the evaluated combination (model, output type, space metric,
    evaluation metric, target, estimator) and stores its AUC and PRR scores.
    """

    # Display name of the model
    model: str
    # Output type: 'beam' or 'bayes'
    output: str
    # Metric used for the minimum Bayes action, or 'N/A' for beam outputs
    space_metric: str
    # Metric used for evaluation (e.g. 'Hamming', 'F1')
    eval_metric: str
    # Evaluation target: 'Distance' or 'Risk'
    target: str
    # Name of the UQ estimator that was scored
    estimator: str
    # AUC score of the estimator against the target
    auc: float
    # Normalised prediction-rejection score of the estimator against the target
    prr: float


@dataclass
class OutputConfig:
    """Describes one output type to evaluate.

    Pairs a display name and space metric with the key under which the
    corresponding predictions are stored in ``all_metrics``.
    """

    # Output type name: 'beam' or 'bayes'
    name: str
    # Metric used for the minimum Bayes action, or 'N/A'
    space_metric: str
    # Key in all_metrics[model_key] holding the per-item predictions
    prediction_key: str


# =============================================================================
# Loss Functions
# =============================================================================

def hamming_loss(s1: np.ndarray, s2: np.ndarray) -> float:
    """
    Compute Hamming distance (number of differing positions).
    
    Args:
        s1: First binary array (any array-like is accepted)
        s2: Second binary array (any array-like is accepted)
        
    Returns:
        Number of positions where s1 and s2 differ
    """
    # Coerce to arrays: for plain Python lists ``s1 != s2`` is a single bool,
    # which would make the result 0 or 1 instead of a per-position count.
    return float(np.sum(np.asarray(s1) != np.asarray(s2)))


def f1_loss(s1: np.ndarray, s2: np.ndarray) -> float:
    """
    Compute F1 loss (1 - F1 score).
    
    Positives are positions equal to 1. If neither array has any positive
    position, the labelings agree on every positive (none exist), so the
    loss is 0.0. This keeps loss(x, x) == 0 for all-negative inputs, which
    matters when the loss is used as a distance or inside a Bayes risk.
    
    Args:
        s1: First binary array (ground truth); any array-like is accepted
        s2: Second binary array (prediction); any array-like is accepted
        
    Returns:
        1 - F1 score (lower is better for loss)
    """
    # Coerce to arrays so that list inputs are compared element-wise.
    truth = np.asarray(s1)
    pred = np.asarray(s2)

    tp = int(np.sum((truth == 1) & (pred == 1)))
    fp = int(np.sum((truth == 0) & (pred == 1)))
    fn = int(np.sum((truth == 1) & (pred == 0)))

    denominator = 2 * tp + fp + fn
    if denominator == 0:
        # No positive in either array: treat as a perfect match.
        return 0.0

    # 2TP / (2TP + FP + FN) equals 2PR / (P + R) whenever that is defined,
    # and is 0 when tp == 0 (matching the precision/recall formulation).
    return 1.0 - (2 * tp) / denominator


def exact_match_loss(s1: np.ndarray, s2: np.ndarray) -> float:
    """
    0-1 loss: zero only when the two arrays match exactly.
    
    Args:
        s1: Reference array (ground truth)
        s2: Candidate array (prediction)
        
    Returns:
        0.0 when s1 and s2 have the same shape and elements, else 1.0
    """
    if np.array_equal(s1, s2):
        return 0.0
    return 1.0


# =============================================================================
# Simplex Loss Functions (for probability distributions)
# =============================================================================

def kl_divergence(p: np.ndarray, q: np.ndarray, epsilon: float = 1e-10) -> float:
    """Return KL(p||q) = sum_i p_i * log(p_i / q_i) on smoothed inputs.
    
    Both inputs get ``epsilon`` added to every entry and are then rescaled
    to sum to one, so zero entries never reach the logarithm.
    
    Args:
        p: Reference distribution (ground truth)
        q: Approximating distribution (prediction)
        epsilon: Additive smoothing constant applied to both inputs
    
    Returns:
        KL divergence (non-negative, 0 if p==q)
    """
    smoothed_p = np.asarray(p, dtype=np.float64) + epsilon
    smoothed_q = np.asarray(q, dtype=np.float64) + epsilon
    # Rescale after smoothing so both are proper distributions again.
    smoothed_p /= smoothed_p.sum()
    smoothed_q /= smoothed_q.sum()
    terms = smoothed_p * np.log(smoothed_p / smoothed_q)
    return float(terms.sum())


def reverse_kl_divergence(p: np.ndarray, q: np.ndarray, epsilon: float = 1e-10) -> float:
    """Return KL(q||p) = sum_i q_i * log(q_i / p_i) (arguments swapped)."""
    # KL is asymmetric; swapping the roles of p and q gives the reverse direction.
    return kl_divergence(p=q, q=p, epsilon=epsilon)


def js_divergence(p: np.ndarray, q: np.ndarray, epsilon: float = 1e-10) -> float:
    """Compute Jensen-Shannon divergence: symmetric version of KL.
    
    JS(p||q) = 0.5 * KL(p||m) + 0.5 * KL(q||m) where m = 0.5*(p+q)
    
    Args:
        p: First probability distribution
        q: Second probability distribution
        epsilon: Additive smoothing constant applied once to both inputs
    
    Returns:
        JS divergence in nats (0 if p == q, at most log(2))
    """
    p = np.asarray(p, dtype=np.float64) + epsilon
    q = np.asarray(q, dtype=np.float64) + epsilon
    p = p / p.sum()
    q = q / q.sum()
    m = 0.5 * (p + q)
    # p and q are already smoothed and normalised, and m is strictly positive,
    # so the KL terms are computed directly. Routing through kl_divergence would
    # add epsilon and renormalise a second time (double smoothing).
    kl_pm = np.sum(p * np.log(p / m))
    kl_qm = np.sum(q * np.log(q / m))
    return float(0.5 * kl_pm + 0.5 * kl_qm)


def l2_distance(p: np.ndarray, q: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between two distributions."""
    diff = np.asarray(p) - np.asarray(q)
    squared = diff ** 2
    return float(np.sqrt(squared.sum()))


def l1_distance(p: np.ndarray, q: np.ndarray) -> float:
    """Compute L1 (Manhattan) distance: sum_i |p_i - q_i|.
    
    Note: for probability distributions this equals twice the total
    variation distance (TV = 0.5 * L1); no 0.5 factor is applied here.
    """
    return float(np.sum(np.abs(np.asarray(p) - np.asarray(q))))


def hellinger_distance(p: np.ndarray, q: np.ndarray, epsilon: float = 1e-10) -> float:
    """Return the Hellinger distance between two smoothed distributions.
    
    H(p,q) = sqrt(0.5 * sum_i (sqrt(p_i) - sqrt(q_i))^2)
    
    Both inputs get ``epsilon`` added per entry and are rescaled to sum to one.
    """
    smoothed_p = np.asarray(p, dtype=np.float64) + epsilon
    smoothed_q = np.asarray(q, dtype=np.float64) + epsilon
    smoothed_p /= smoothed_p.sum()
    smoothed_q /= smoothed_q.sum()
    root_gap = np.sqrt(smoothed_p) - np.sqrt(smoothed_q)
    half_sq_sum = 0.5 * np.sum(root_gap ** 2)
    return float(np.sqrt(half_sq_sum))


# Losses for binary/structured outputs, keyed by display name (insertion order preserved)
DEFAULT_LOSS_FUNCTIONS: Dict[str, Callable] = dict([
    ('Hamming', hamming_loss),
    ('F1', f1_loss),
    ('Exact Match', exact_match_loss),
])

# Losses/divergences for probability-simplex outputs, keyed by display name
SIMPLEX_LOSS_FUNCTIONS: Dict[str, Callable] = dict([
    ('KL', kl_divergence),
    ('Reverse KL', reverse_kl_divergence),
    ('JS', js_divergence),
    ('L2', l2_distance),
    ('L1', l1_distance),
    ('Hellinger', hellinger_distance),
])


# =============================================================================
# Helper Functions
# =============================================================================

def compute_bayes_risk(
    prediction: np.ndarray, 
    samples: np.ndarray, 
    loss_func: Callable
) -> float:
    """
    Estimate the Bayes risk of ``prediction`` as its mean loss over ``samples``.
    
    Args:
        prediction: The prediction, passed as the first argument to ``loss_func``
        samples: Iterable of individual samples, e.g. an array of shape
            (n_samples, n_features) or a list of 1-D arrays
        loss_func: Loss function, called as ``loss_func(prediction, sample)``
        
    Returns:
        Expected loss (Bayes risk), or NaN if ``samples`` is empty
    """
    # Iterate directly so list-of-arrays samples work too (no ``.shape`` needed).
    risks = [loss_func(prediction, sample) for sample in samples]
    if not risks:
        # The mean over zero samples is undefined; return NaN explicitly rather
        # than letting np.mean emit a "Mean of empty slice" RuntimeWarning.
        return float('nan')
    return float(np.mean(risks))


def compute_prr(
    ue: np.ndarray, 
    metrics: np.ndarray,
    prr_calculator: Optional[PredictionRejectionArea] = None,
    max_rejection: float = 0.5
) -> float:
    """
    Compute normalized Prediction Rejection Rate.
    
    The raw rejection area of ``ue`` is rescaled so that a random ordering
    scores about 0 and an oracle ordering (uncertainty = -metric) scores 1:
    ``(PRA(ue) - PRA(random)) / (PRA(oracle) - PRA(random))``.
    
    Args:
        ue: Uncertainty estimates (array-like)
        metrics: Performance metrics (negative of error, so higher is better)
        prr_calculator: Optional pre-initialized PRR calculator; when given,
            ``max_rejection`` is ignored
        max_rejection: Maximum rejection rate for PRR calculation
        
    Returns:
        Normalized PRR score, or 0.0 when oracle and random scores coincide
        (the normalization would divide by ~0)
    """
    # Coerce to arrays: ``-metrics`` below needs element-wise negation, which
    # plain Python lists do not support.
    ue = np.asarray(ue)
    metrics = np.asarray(metrics)

    if prr_calculator is None:
        prr_calculator = PredictionRejectionArea(max_rejection=max_rejection)
    
    mean_val = prr_calculator(ue, metrics)
    # Oracle: uncertainty is the negated metric, so the worst items rank as most uncertain.
    oracle = prr_calculator(-metrics, metrics)
    random = get_random_scores(prr_calculator, metrics)
    
    denominator = oracle - random
    if abs(denominator) < 1e-10:
        return 0.0
    
    return float((mean_val - random) / denominator)


def get_estimator_values(
    all_metrics: Dict, 
    key: str,
    estimator_keys: Optional[Dict[str, str]] = None
) -> Dict[str, np.ndarray]:
    """
    Extract all estimator values for a given model/dataset key.
    
    Estimators whose key is missing, whose value is None, or whose value is
    empty are skipped.
    
    Args:
        all_metrics: Dictionary containing metrics for all models
        key: The model/dataset key
        estimator_keys: Optional mapping from estimator names to keys in all_metrics
                       Default: {'SE': 'se', 'MSP': 'msp', 'KLE': 'kle', 
                                'COCOA': 'cocoa', 'SAR': 'sar', 'PTrue': 'p_true'}
        
    Returns:
        Dictionary mapping estimator names to their values as numpy arrays
        
    Raises:
        KeyError: If ``key`` is not present in ``all_metrics``
    """
    if estimator_keys is None:
        estimator_keys = {
            'SE': 'se',
            'MSP': 'msp',
            'KLE': 'kle',
            'COCOA': 'cocoa',
            'SAR': 'sar',
            'PTrue': 'p_true'
        }
    
    model_metrics = all_metrics[key]
    estimators = {}
    for est_name, est_key in estimator_keys.items():
        if est_key not in model_metrics:
            continue
        values = model_metrics[est_key]
        # Treat a None placeholder like an empty list instead of crashing on len(None).
        if values is None or len(values) == 0:
            continue
        estimators[est_name] = np.array(values)
    
    return estimators


# =============================================================================
# Main Analysis Functions
# =============================================================================

def run_uq_analysis(
    all_dfs: Dict,
    all_metrics: Dict,
    model_key_conversion: Optional[Dict[str, str]] = None,
    output_configs: Optional[List[OutputConfig]] = None,
    loss_functions: Optional[Dict[str, Callable]] = None,
    estimator_keys: Optional[Dict[str, str]] = None,
    prr_max_rejection: float = 0.5
) -> List[UQResult]:
    """
    Run comprehensive UQ analysis across all combinations.
    
    For every model, output configuration present in ``all_metrics``,
    evaluation metric, target ('Risk', 'Distance') and estimator (the
    sample-based 'BayesRisk' estimate plus the extracted estimators), one
    UQResult with AUC and PRR is produced.
    
    Loss argument order used below:
      - BayesRisk estimate: mean over samples of loss(prediction, sample)
      - 'Risk' target:      mean over samples of loss(z_star, sample)
      - 'Distance' target:  loss(z_star, prediction)
    
    Args:
        all_dfs: Dictionary of DataFrames containing the data for each model;
                 each row's "metrics_arrays" entry must hold "generations_arrays"
        all_metrics: Dictionary containing pre-computed metrics for each model,
                     including "z_star" and the configs' prediction keys
        model_key_conversion: Optional mapping from raw model keys to display names
        output_configs: List of OutputConfig objects defining the output types to analyze.
                       Configs whose prediction_key is missing are skipped.
                       Default: beam, bayes (Hamming), bayes (F1)
        loss_functions: Dictionary mapping metric names to loss functions.
                       Default: DEFAULT_LOSS_FUNCTIONS (Hamming, F1, Exact Match)
        estimator_keys: Optional mapping for estimator extraction
        prr_max_rejection: Maximum rejection rate for PRR calculation
        
    Returns:
        List of UQResult objects containing all results
    """
    if model_key_conversion is None:
        model_key_conversion = {}
    
    if output_configs is None:
        output_configs = [
            OutputConfig('beam', 'N/A', 'z_beam'),
            OutputConfig('bayes', 'Hamming', 'z_hamming_bayes'),
            OutputConfig('bayes', 'F1', 'z_f1_bayes'),
        ]
    
    if loss_functions is None:
        loss_functions = DEFAULT_LOSS_FUNCTIONS.copy()
    
    # One PRR calculator shared across all evaluations
    prr_calculator = PredictionRejectionArea(max_rejection=prr_max_rejection)
    
    results = []
    
    for key, df in all_dfs.items():
        model_name = model_key_conversion.get(key, key)
        model_metrics = all_metrics[key]
        
        # Extract samples and ground truth
        z_samples = df["metrics_arrays"].apply(lambda x: x["generations_arrays"]).tolist()
        z_star = model_metrics["z_star"]
        n_items = len(z_samples)
        
        # Get estimator values
        estimators = get_estimator_values(all_metrics, key, estimator_keys)
        
        # Only configs whose predictions are available for this model
        active_configs = [c for c in output_configs if c.prediction_key in model_metrics]
        if not active_configs:
            continue
        
        # The 'Risk' target depends only on z_star and the samples, not on the
        # output config, so compute it once per metric instead of once per config.
        true_risks = {
            eval_metric: np.array([
                compute_bayes_risk(z_star[i], z_samples[i], loss_func)
                for i in range(n_items)
            ])
            for eval_metric, loss_func in loss_functions.items()
        }
        
        for config in active_configs:
            z_predictions = model_metrics[config.prediction_key]
            
            for eval_metric, loss_func in loss_functions.items():
                # Sample-based Bayes risk estimate of this config's prediction
                bayes_risk_estimate = np.array([
                    compute_bayes_risk(z_predictions[i], z_samples[i], loss_func)
                    for i in range(n_items)
                ])
                
                # True Distance: loss(z_star, z_prediction)
                true_distance = np.array([
                    loss_func(z_star[i], z_predictions[i])
                    for i in range(n_items)
                ])
                
                targets = {
                    'Risk': true_risks[eval_metric],
                    'Distance': true_distance
                }
                
                # Combine BayesRisk estimator with other estimators
                all_estimators = {'BayesRisk': bayes_risk_estimate, **estimators}
                
                # Evaluate each estimator against each target
                for target_name, target_values in targets.items():
                    for estimator_name, estimator_values in all_estimators.items():
                        if len(estimator_values) == 0:
                            continue
                        
                        auc = auc_cont(target_values, estimator_values)
                        # PRR expects "higher is better" metrics, so negate the loss-like target.
                        prr = compute_prr(
                            estimator_values, 
                            -1 * target_values, 
                            prr_calculator
                        )
                        
                        results.append(UQResult(
                            model=model_name,
                            output=config.name,
                            space_metric=config.space_metric,
                            eval_metric=eval_metric,
                            target=target_name,
                            estimator=estimator_name,
                            auc=auc,
                            prr=prr
                        ))
    
    return results


def results_to_dataframe(results: List[UQResult]) -> pd.DataFrame:
    """
    Convert a list of UQResult objects to a pandas DataFrame.
    
    The column set is fixed, so an empty ``results`` list still yields a
    DataFrame with the expected (empty) columns instead of a column-less frame.
    
    Args:
        results: List of UQResult objects
        
    Returns:
        DataFrame with columns: Model, Output, SpaceMetric, EvaluationMetric,
                               Target, Estimator, AUC, PRR
    """
    columns = [
        'Model', 'Output', 'SpaceMetric', 'EvaluationMetric',
        'Target', 'Estimator', 'AUC', 'PRR',
    ]
    rows = [
        (r.model, r.output, r.space_metric, r.eval_metric,
         r.target, r.estimator, r.auc, r.prr)
        for r in results
    ]
    return pd.DataFrame(rows, columns=columns)


def create_uq_table(
    all_dfs: Dict,
    all_metrics: Dict,
    model_key_conversion: Optional[Dict[str, str]] = None,
    output_configs: Optional[List[OutputConfig]] = None,
    loss_functions: Optional[Dict[str, Callable]] = None,
    estimator_keys: Optional[Dict[str, str]] = None,
    prr_max_rejection: float = 0.5,
    round_digits: Optional[int] = None,
) -> pd.DataFrame:
    """
    Convenience function to create the full UQ results table.
    
    This is the main entry point for creating consolidated UQ tables.
    Rows are in the order produced by ``run_uq_analysis``; no sorting is applied.
    
    Args:
        all_dfs: Dictionary of DataFrames containing the data for each model
        all_metrics: Dictionary containing pre-computed metrics for each model
        model_key_conversion: Optional mapping from raw model keys to display names
        output_configs: List of OutputConfig objects defining the output types
        loss_functions: Dictionary mapping metric names to loss functions
        estimator_keys: Optional mapping for estimator extraction
        prr_max_rejection: Maximum rejection rate for PRR calculation
        round_digits: Number of decimal places to round numeric columns to;
                      None (default) leaves values unrounded
        
    Returns:
        DataFrame with all UQ results (rounded if ``round_digits`` is given)
        
    Example:
        >>> from structured_llmuq.utils.results_utils import create_uq_table, OutputConfig
        >>> 
        >>> # Define custom output configurations for your specific bayes estimators
        >>> output_configs = [
        ...     OutputConfig('beam', 'N/A', 'z_beam'),
        ...     OutputConfig('bayes', 'MyMetric', 'z_my_bayes'),
        ... ]
        >>> 
        >>> uq_df = create_uq_table(
        ...     all_dfs=all_dfs,
        ...     all_metrics=all_metrics,
        ...     model_key_conversion={'model_key': 'Display Name'},
        ...     output_configs=output_configs
        ... )
    """
    results = run_uq_analysis(
        all_dfs=all_dfs,
        all_metrics=all_metrics,
        model_key_conversion=model_key_conversion,
        output_configs=output_configs,
        loss_functions=loss_functions,
        estimator_keys=estimator_keys,
        prr_max_rejection=prr_max_rejection
    )
    
    df = results_to_dataframe(results)
    if round_digits is not None:
        df = df.round(round_digits)
    return df



