"""
Differentiable Coherent Factuality

This module implements a fully differentiable approximation of Coherent Factuality
for LLM outputs with reasoning structure, enabling gradient-based optimization of
claim scoring functions while preserving conformal prediction guarantees.

All functions accept an optional `debugger` parameter (FactualityDebugger instance)
for detailed logging.

Usage:
    from src.debugger import FactualityDebugger

    # Create debugger (enabled=True for debug output)
    debugger = FactualityDebugger(enabled=True)

    # Calibration
    tau = calibrate(X_cal, Y_cal, noise_cal, scorer, alpha=0.1, debugger=debugger)

    # Prediction
    probs = predict(X_test, noise_test, scorer, tau, debugger=debugger)

    # Loss for training
    loss = compute_soft_retention_loss(probs, Y_test, debugger=debugger)
"""

import torch
import torch.nn as nn
import numpy as np
import networkx as nx
from typing import Dict, List, Any, Tuple, Optional
from .debugger import FactualityDebugger
import torchsort


def soft_keep(
    scores: torch.Tensor,
    tau_list: torch.Tensor,
    temp: float = 0.2,
    debugger: FactualityDebugger | None = None
) -> torch.Tensor:
    """
    Sigmoid-smoothed threshold filtering.

    Args:
        scores: [n] risk scores
        tau_list: [T] threshold values
        temp: temperature for sigmoid smoothness
        debugger: optional debugger instance

    Returns:
        [n, T] soft keep probabilities for each node vs each tau
    """
    risk = scores
    margin = tau_list[None, :] - risk[:, None]
    result = torch.sigmoid(margin / temp)

    if debugger:
        debugger.soft_keep(risk, tau_list, temp, margin, result)

    return result


def ancestor_coherence(
    scores: torch.Tensor,
    ancestors: torch.Tensor,
    gamma: float = 1.0,
    eps: float = 1e-9,
    debugger: FactualityDebugger | None = None
) -> torch.Tensor:
    """
    Enforce ancestor coherence with optional decay.

    Args:
        scores: [n, T] per-node keep probs at each τ
        ancestors: [n, n] bool/int ancestor matrix
        gamma: decay factor (1.0=full, 0.0=none, >1.0=amplified)
        eps: epsilon for numerical stability
        debugger: optional debugger instance

    Returns:
        log_coherent: [n, T] log-space coherence-adjusted probabilities
    """
    n, T = scores.shape
    log_scores = torch.log(scores + eps)

    ancestor_with_self = ancestors | torch.eye(n, dtype=torch.bool, device=ancestors.device)

    # Weight matrix: ancestors get gamma, self gets 1.0
    weights = ancestor_with_self.float() * gamma
    weights[torch.arange(n), torch.arange(n)] = 1.0

    # Normalize per column
    weights = weights / (weights.sum(dim=0, keepdim=True) + eps)

    # Weighted average of original scores
    log_coherent = weights.T @ log_scores

    if debugger:
        debugger.ancestor_coherence(scores, ancestors, gamma, eps, weights, log_scores, log_coherent)

    return log_coherent


def size_invariant_validity_negatives(
    log_coherent: torch.Tensor,
    labels: torch.Tensor,
    eps: float = 1e-6,
    debugger: FactualityDebugger | None = None
) -> torch.Tensor:
    """
    Validity on negatives (log-space).

    Args:
        log_coherent: [n, T] log coherence probabilities
        labels: [n] binary labels
        eps: epsilon for numerical stability
        debugger: optional debugger instance

    Returns:
        log Q_τ over the τ grid: shape [T], each ≤ 0
    """
    neg_mask = (labels == 0).float().unsqueeze(1)  # [n, 1]
    N_neg = neg_mask.sum().clamp(min=1.0)
    probs = log_coherent.exp().clamp(max=1.0 - eps)
    log_terms = torch.log1p(-probs) * neg_mask     # log(1 - p)
    logQ = log_terms.sum(dim=0) / N_neg            # [T], ≤ 0

    if debugger:
        debugger.validity_negatives(log_coherent, labels, eps, N_neg.item(), probs, log_terms, logQ)

    return logQ


def violation_from_logQ(
    log_validity: torch.Tensor,
    violation_mode: str = "minmax",
    squash_temp: float = 1.0,
    eps: float = 1e-12,
    debugger: Optional['FactualityDebugger'] = None
) -> torch.Tensor:
    """
    Transform log-validity into violation scores V in [0, 1].

    Higher V means more violation. In every mode V is non-increasing in
    log_validity and reaches 0 at the best (largest) log_validity on the grid
    (exactly 0 in the min-max modes, and as log_validity -> 0 in
    "exponential").

    Modes:

    1. "minmax" (default, from the paper):
           z_hat = (log_validity - min) / (max - min + eps)
           V = 1 - z_hat
       Data-driven, uses the full [0, 1] range, no extra hyperparameters.

    2. "exponential":
           V = 1 - exp(log_validity / squash_temp)
       Because log_validity <= 0: log_validity -> -inf gives V -> 1, and
       log_validity -> 0 gives V -> 0. A larger squash_temp gives a smoother
       transition.

    3. "minmax_exponential":
           V = 1 - exp(-(1 - z_hat) / squash_temp)
       This is the "exponential" mode applied to the min-max rescaled value
       z_hat - 1 in [-1, 0]. V therefore lies in [0, 1 - exp(-1 / squash_temp)].

    Args:
        log_validity: [T] log validity scores (each <= 0).
        violation_mode: one of ["minmax", "exponential", "minmax_exponential"].
        squash_temp: temperature for the exponential modes (ignored by
            "minmax"). Higher values give a gentler penalty.
        eps: epsilon for numerical stability in the min-max normalisation.
        debugger: optional debugger instance.

    Returns:
        V: [T] violation scores in [0, 1].

    Raises:
        ValueError: for an unknown violation_mode.
    """

    if violation_mode == "minmax":
        # Min-max normalization (original paper formulation)
        z = log_validity
        z_min = torch.min(z)
        z_max = torch.max(z)
        z_hat = (z - z_min) / (z_max - z_min + eps)
        V = 1.0 - z_hat

    elif violation_mode == "exponential":
        # More negative log_validity -> higher V (approaches 1).
        V = 1.0 - torch.exp(log_validity / squash_temp)
        z_hat = None  # Not used in this mode

    elif violation_mode == "minmax_exponential":
        # Hybrid: normalise first, then apply the exponential mapping.
        z = log_validity
        z_min = torch.min(z)
        z_max = torch.max(z)
        z_hat = (z - z_min) / (z_max - z_min + eps)
        # Sign fix: the previous V = exp(-(1 - z_hat) / squash_temp) was
        # *increasing* in validity (V = 1 at the best threshold), the opposite
        # of the other two modes. Use 1 - exp(...) so V = 0 at best validity.
        V = 1.0 - torch.exp(-(1.0 - z_hat) / squash_temp)

    else:
        raise ValueError(f"Unknown violation_mode: {violation_mode}. "
                        f"Must be one of ['minmax', 'exponential', 'minmax_exponential']")

    if debugger:
        z_min_val = torch.min(log_validity).item()
        z_max_val = torch.max(log_validity).item()
        z_hat_for_debug = z_hat if violation_mode in ["minmax", "minmax_exponential"] else None

        # Pass squash_temp as eps_violation for backward compatibility with debugger
        debugger.violation_from_logQ(
            log_validity,
            squash_temp,  # Passed as eps_violation parameter
            z_min_val,
            z_max_val,
            z_hat_for_debug,
            V
        )

    return V


# def violation_from_logQ(
#     log_validity: torch.Tensor,
#     tau: float = 1.0,
#     offset: float = 0.0,
#     debugger: Optional['FactualityDebugger'] = None
# ) -> torch.Tensor:
#     """
#     Transform log-validity into an increasing 'violation' curve using sigmoid:
#         V = 1 - sigmoid((log_validity + offset) / tau)

#     This maps log_validity → V ∈ [0,1] such that:
#         - More negative log_validity → Higher V (more violation)
#         - log_validity ≈ -offset → V ≈ 0.5 (transition point)
#         - As tau → 0, approaches step function at log_validity = -offset

#     Args:
#         log_validity: [T] log validity scores (each ≤ 0)
#         tau: temperature parameter controlling sharpness of transition
#              (smaller tau → sharper step function as tau → 0)
#         offset: controls the center of the sigmoid transition
#                 (positive offset → easier to avoid violations → higher retention)
#         debugger: optional debugger instance

#     Returns:
#         V: [T] ∈ [0,1], where higher V means more violation

#     Theoretical properties:
#         - As tau → 0, V → step function: V = 1 if log_validity < -offset, else 0
#         - Positive offset shifts transition right, making violations less likely
#         - Negative offset shifts transition left, making violations more likely
#         - This recovers hard conformal behavior in the limit
#     """
#     # Invert sigmoid so more negative log_validity gives higher violation
#     V = 1.0 - torch.sigmoid((log_validity + offset) / tau)

#     if debugger:
#         debugger.violation_from_logQ(
#             log_validity,
#             tau,
#             torch.min(log_validity).item(),
#             torch.max(log_validity).item(),
#             None,  # no z_hat in this version
#             V
#         )

#     return V


def soft_supremum_from_violation(
    tau_list: torch.Tensor,
    V: torch.Tensor,
    beta: float = 8.0,
    lambda_: float = 1.0,
    eps: float = 1e-12,
    debugger: FactualityDebugger | None = None
):
    """
    Soft supremum with trade-off between high τ and low violation.

    Args:
        tau_list: [T] tau grid values
        V: [T] violation scores
        beta: softmax sharpness (higher = sharper, closer to hard max)
        lambda_: violation penalty weight (>1.0=conservative, <1.0=aggressive)
        eps: epsilon for numerical stability
        debugger: optional debugger instance

    Returns:
        tau_tilde: soft threshold (weighted average)
        pi: soft weights over tau grid
    """
    norm = lambda z: (z - z.min()) / (z.max() - z.min() + eps)
    tau_hat = norm(tau_list)
    V_hat = norm(V)

    s = tau_hat - lambda_ * V_hat
    pi = torch.softmax(beta * s, dim=0)
    tau_tilde = (pi * tau_list).sum()

    if debugger:
        debugger.soft_supremum(tau_list, V, beta, lambda_, eps, tau_hat, V_hat, s, pi, tau_tilde)

    return tau_tilde, pi


def compute_risk(
    scores: torch.Tensor,
    adj: torch.Tensor,
    C: float = 6.0,
    beta_mix: float = 0.0,
    scalar_noise: float | None = None,
    debugger: FactualityDebugger | None = None
) -> torch.Tensor:
    """
    Compute risk scores from model scores.

    Args:
        scores: [n] model scores for each claim
        adj: [n, n] adjacency matrix (parent->child)
        C: constant for risk transformation (risk = C - score)
        beta_mix: mixing parameter for descendant median (0.0 = no mixing)
        scalar_noise: optional additive noise to risk
        debugger: optional debugger instance

    Returns:
        risk: [n] risk scores
    """
    risk = C - scores
    initial_risk = risk.clone()
    node_details = []

    if beta_mix > 0:
        n = risk.numel()
        r2 = risk.clone()
        for i in range(n):
            desc = torch.where(adj[i, :] == 1)[0]
            if desc.numel() > 0:
                med = risk[desc].median()
                r2[i] = (1 - beta_mix) * risk[i] + beta_mix * med
                if debugger:
                    node_details.append({
                        'node': i,
                        'descendants': desc.tolist(),
                        'median': med.item(),
                        'old_risk': risk[i].item(),
                        'new_risk': r2[i].item()
                    })
        risk_after_mix = r2.clone()
        risk = r2
    else:
        risk_after_mix = None

    if scalar_noise is not None:
        risk = risk + float(scalar_noise)

    if debugger:
        debugger.compute_risk(scores, C, beta_mix, scalar_noise, initial_risk, risk_after_mix, risk, node_details if node_details else None)

    return risk


def build_tau_grid(
    risk: torch.Tensor,
    margin: float = 20.0,
    debugger: FactualityDebugger | None = None
) -> torch.Tensor:
    """
    Build tau grid from risk scores.

    Args:
        risk: [n] risk scores for each claim
        margin: margin to add to min/max risk values
        debugger: optional debugger instance

    Returns:
        tau_list: [T] sorted tau values forming the grid
    """
    # NOTE: We detach to prevent gradient flow through unique operation, which is not differentiable
    # However, tau_min and tau_max still have gradients
    tau_unique = torch.unique(risk.detach(), sorted=True)
    tau_min = risk.min() - margin
    tau_max = risk.max() + margin
    tau_list = torch.cat([tau_min.view(1), tau_unique, tau_max.view(1)], dim=0)

    if debugger:
        debugger.build_tau_grid(risk, margin, tau_unique, tau_min.item(), tau_max.item(), tau_list)

    return tau_list


# -------------------------------
# Calibration and Prediction
# -------------------------------

def compute_nonconformity_score(
    x: Dict[str, Any],
    y: torch.Tensor,
    noise_val: float,
    scorer: nn.Module,
    C: float = 6.0,
    beta_mix: float = 0.0,
    margin: float = 20.0,
    temp: float = 0.2,
    beta: float = 8.0,
    gamma: float = 1.0,
    lambda_: float = 1.0,
    violation_mode: str = "exponential",
    squash_temp: float = 1.0,
    eps_keep: float = 1e-8,
    eps_val: float = 1e-6,
    debugger: Optional['FactualityDebugger'] = None,
    debug_probs: bool = False,
    return_cache: bool = False
) -> torch.Tensor | Tuple[torch.Tensor, Dict]:
    """
    Nonconformity score of one example via the differentiable soft-supremum pipeline.

    The stages, each a function in this module, are:
      scorer -> compute_risk -> build_tau_grid (then sorted) -> soft_keep
      -> ancestor_coherence -> size_invariant_validity_negatives
      -> violation_from_logQ -> soft_supremum_from_violation.

    All tensors are placed on the device of the scorer's first parameter, or
    on CPU if the scorer has no parameters. Scores are cast to float32.

    Args:
        x: dict with 'ancestors' [n, n], 'adj' [n, n] and 'features' [n, m].
        y: [n] labels in {0, 1}.
        noise_val: scalar noise added to every risk.
        scorer: model mapping features to [n] scores (called via .forward).
        C: constant in risk = C - score.
        beta_mix: mixing weight for the neighbour-median risk.
        margin: margin for the tau grid bounds.
        temp: temperature of the soft-keep sigmoid.
        beta: softmax sharpness of the soft supremum.
        gamma: ancestor weight in ancestor_coherence.
        lambda_: violation penalty weight in the soft supremum.
        violation_mode: one of ["minmax", "exponential", "minmax_exponential"].
        squash_temp: temperature for the exponential violation modes.
        eps_keep: epsilon for ancestor_coherence.
        eps_val: epsilon for size_invariant_validity_negatives.
        debugger: optional debugger instance.
        debug_probs: if True (and a debugger is given), also report
            probability-space coherence and validity.
        return_cache: if True, also return the intermediate tensors.

    Returns:
        tau_tilde: scalar soft threshold in risk space; or
        (tau_tilde, cache) if return_cache, where cache holds "risk",
        "tau_list", "p_keep", "log_coherent", "log_validity", "V" and "pi".
    """
    # Resolve the device from the scorer, because the data tensors may live
    # on CPU. Scorers without parameters (e.g. ForwardScorer) fall back to CPU.
    first_param = next(scorer.parameters(), None)
    device = torch.device('cpu') if first_param is None else first_param.device

    ancestors = torch.as_tensor(x['ancestors'], device=device)
    y = y.to(device)

    if debugger:
        debugger.nonconformity_score_start(
            noise_val, C, beta_mix, margin, temp, beta, gamma, lambda_,
            eps_keep, eps_val, 1e-12,
        )

    raw_scores = scorer.forward(torch.as_tensor(x['features'], device=device))
    scores = raw_scores.to(device).float()

    adj = torch.as_tensor(x['adj'], device=device)
    risk = compute_risk(scores, adj, C, beta_mix, noise_val, debugger)

    tau_list = build_tau_grid(risk, margin, debugger).to(device).sort().values

    p_keep = soft_keep(risk, tau_list, temp=temp, debugger=debugger)
    log_coherent = ancestor_coherence(p_keep, ancestors, gamma=gamma, eps=eps_keep, debugger=debugger)
    log_validity = size_invariant_validity_negatives(log_coherent, y, eps=eps_val, debugger=debugger)
    V = violation_from_logQ(
        log_validity,
        violation_mode=violation_mode,
        squash_temp=squash_temp,
        eps=1e-12,
        debugger=debugger,
    )
    tau_tilde, pi = soft_supremum_from_violation(
        tau_list, V, beta=beta, lambda_=lambda_, eps=1e-12, debugger=debugger
    )

    if debugger:
        tau_risk = tau_tilde.item()
        # C - tau maps the risk-space threshold back to score space.
        debugger.nonconformity_score_summary(risk, tau_list, tau_risk, C - tau_risk, C)

    if debug_probs and debugger:
        debugger.prob_space(
            log_coherent.exp().clamp(max=1.0 - 1e-12),
            log_validity.exp().clamp(max=1.0 - 1e-12),
        )

    if not return_cache:
        return tau_tilde

    cache = {
        "risk": risk, "tau_list": tau_list,
        "p_keep": p_keep, "log_coherent": log_coherent,
        "log_validity": log_validity, "V": V, "pi": pi
    }
    return tau_tilde, cache


def soft_quantile(
    values: torch.Tensor,
    q: float,
    regularization_strength: float = 1e-4,
    debugger: FactualityDebugger | None = None
) -> torch.Tensor:
    """
    Differentiable order statistic, built on torchsort's soft sort.

    The values are soft-sorted in ascending order. The result is a weighted
    sum of the sorted values with Gaussian weights (sigma = 0.5) centred at
    index

        target_idx = ceil((n + 1) * (1 - q)) - 1   (0-indexed)

    Note the direction: q close to 1 selects a *small* index, i.e. the lower
    tail of the ascending sort. Callers must pass q with that in mind. A
    negative target_idx (e.g. q >= 1) puts almost all weight on index 0.

    torchsort does not support MPS, so MPS inputs are processed on CPU and
    the result is moved back.

    Args:
        values: [n] (or [1, n]) tensor of values.
        q: level used in the target_idx formula above.
        regularization_strength: torchsort soft-sort regularisation
            (smaller = closer to a hard sort).
        debugger: optional debugger instance.

    Returns:
        Scalar tensor on the input's original device.
    """
    # torchsort expects a [batch, n] input.
    if values.ndim == 1:
        values = values.unsqueeze(0)

    home_device = values.device
    moved_off_mps = home_device.type == 'mps'
    if moved_off_mps:
        values = values.cpu()

    n = values.size(1)
    target_idx = np.ceil((n + 1) * (1 - q)) - 1

    sorted_vals = torchsort.soft_sort(values, regularization_strength=regularization_strength)

    # Gaussian bump over sorted positions, centred at target_idx.
    sigma = 0.5
    positions = torch.arange(n, dtype=torch.float32, device=values.device)
    weights = torch.exp(-((positions - target_idx) / sigma) ** 2)
    weights = weights / weights.sum()

    result = (weights * sorted_vals[0]).sum()

    if moved_off_mps:
        result = result.to(home_device)

    if debugger:
        debugger.soft_quantile(values.squeeze(), q, regularization_strength, target_idx, weights, sorted_vals.squeeze(), result)

    return result


def calibrate(
    X: List[Dict[str, Any]],
    Y: List[torch.Tensor],
    noise: List[float],
    scorer: nn.Module,
    alpha: float,
    C: float = 6.0,
    beta_mix: float = 0.0,
    margin: float = 20.0,
    temp: float = 0.2,
    beta: float = 8.0,
    gamma: float = 1.0,
    lambda_: float = 1.0,
    violation_mode: str = "minmax",
    squash_temp: float = 1.0,
    eps_keep: float = 1e-8,
    eps_val: float = 1e-6,
    regularization_strength: float = 1e-2,
    debugger: Optional['FactualityDebugger'] = None
) -> torch.Tensor:
    """
    Fully differentiable calibration procedure.

    Computes the soft nonconformity score (compute_nonconformity_score) of
    every calibration example. It then takes a soft order statistic of those
    scores with

        q = ceil((n_cal + 1) * (1 - alpha)) / n_cal.

    soft_quantile centres its weights at ascending index
    ceil((n_cal + 1) * (1 - q)) - 1. By the arithmetic of the two formulas,
    whenever 0 < n_cal - ceil((n_cal + 1)(1 - alpha)) < n_cal this equals
    n_cal - ceil((n_cal + 1)(1 - alpha)). That is the mirror of the usual
    conformal index, i.e. a lower-tail order statistic of the risk-space
    thresholds.

    Args:
        X: list of calibration examples (graph dicts).
        Y: list of label tensors, same length as X.
        noise: list of noise values, same length as X.
        scorer: neural network model.
        alpha: desired miscoverage level (e.g. 0.1 for 90% coverage).
        C, beta_mix, margin, temp, beta, gamma, lambda_: pipeline
            hyperparameters, forwarded to compute_nonconformity_score.
        violation_mode, squash_temp: violation transform settings, also
            forwarded. The default "minmax" here differs from the
            compute_nonconformity_score default, but this value is always
            passed explicitly.
        eps_keep, eps_val: epsilon values.
        regularization_strength: soft-sort regularisation for soft_quantile.
        debugger: optional debugger instance.

    Returns:
        tau_tilde: calibrated threshold (soft order statistic of the scores).

    Raises:
        ValueError: if X is empty, or X, Y and noise differ in length.
            Previously zip() silently dropped the extra examples, which
            shrinks the calibration set and shifts the quantile.
    """
    n_cal = len(X)
    if n_cal == 0:
        raise ValueError("calibrate() needs at least one calibration example; X is empty.")
    if len(Y) != n_cal or len(noise) != n_cal:
        raise ValueError(
            f"X, Y and noise must have the same length; "
            f"got {n_cal}, {len(Y)} and {len(noise)}."
        )

    if debugger:
        debugger.calibration_start(n_cal, alpha, C, beta_mix, temp, beta, gamma, lambda_)

    taus = []

    for i, (x, y, noise_val) in enumerate(zip(X, Y, noise)):
        if debugger:
            debugger.calibration_example(i + 1, n_cal)

        # Keyword arguments guard against silent mis-binding if the
        # signature of compute_nonconformity_score is ever reordered.
        tau_est = compute_nonconformity_score(
            x, y, noise_val, scorer,
            C=C, beta_mix=beta_mix, margin=margin, temp=temp, beta=beta,
            gamma=gamma, lambda_=lambda_, violation_mode=violation_mode,
            squash_temp=squash_temp, eps_keep=eps_keep, eps_val=eps_val,
            debugger=debugger,
        )
        taus.append(tau_est)

    taus_tensor = torch.stack(taus)
    q = np.ceil((n_cal + 1) * (1 - alpha)) / n_cal

    # Use SOFT quantile for differentiability
    tau_tilde = soft_quantile(taus_tensor, q, regularization_strength, debugger)

    if debugger:
        debugger.calibration_quantile(taus_tensor, q, tau_tilde)
        debugger.calibration_end()

    return tau_tilde


def predict(
    X: List[Dict[str, Any]],
    noise: List[float],
    scorer: nn.Module,
    tau_tilde: torch.Tensor,
    C: float = 6.0,
    beta_mix: float = 0.0,
    margin: float = 20.0,
    temp: float = 0.2,
    beta: float = 20.0,
    gamma: float = 1.0,
    cutoff_temp: float = 0.01,
    eps: float = 1e-8,
    debugger: Optional['FactualityDebugger'] = None
) -> List[torch.Tensor]:
    """
    Soft prediction: a differentiable, soft-gated weighted average over the threshold grid.

    Relaxes the hard Coherent Factuality prediction rule

        Hard: argmax_tau tau   s.t.  tau <= tau_tilde

    with per-threshold weights

        w_tau       = exp(beta * tau) * sigmoid((tau_tilde - tau) / cutoff_temp)
        weights_tau = w_tau / (sum_tau' w_tau' + eps)

    Each node's prediction is sum_tau weights_tau * (coherent keep probability
    at tau).

    - exp(beta * tau) favours larger thresholds; as beta -> inf it picks the
      largest admissible one.
    - The sigmoid gate suppresses thresholds above tau_tilde; as
      cutoff_temp -> 0 it becomes a hard cutoff.

    The grid size does not depend on tau_tilde, so gradients reach tau_tilde
    through the gate.

    The weights are evaluated in log space; the result equals the formula
    above exactly. Computed directly, exp(beta * tau) overflows float32 once
    beta * tau > ~88.7. The grid always extends `margin` above the largest
    risk, so with the defaults (beta=20, margin=20) this happens unless every
    risk is below about -15.6. The overflow then produced inf * 0 or inf / inf,
    i.e. NaN weights and NaN predictions.

    Args:
        X: list of test examples (dicts with 'features', 'adj', 'ancestors').
        noise: list of noise values, one per example.
        scorer: model that scores graph nodes (called via .forward).
        tau_tilde: calibrated threshold (scalar tensor or Python number).
        C: constant for the risk transformation (risk = C - score).
        beta_mix: mixing parameter for the risk computation.
        margin: margin passed to build_tau_grid (the grid spans
            [min risk - margin, max risk + margin]).
        temp: temperature of the soft-keep sigmoid.
        beta: sharpness of the preference for large thresholds.
        gamma: ancestor weight in ancestor_coherence.
        cutoff_temp: temperature of the soft gate (lower = harder cutoff).
        eps: epsilon for ancestor_coherence and for the weight normaliser.
        debugger: optional debugger instance.

    Returns:
        predictions: list of [n] soft keep-probability tensors, one per example.
    """
    if debugger:
        debugger.prediction_start(len(X), tau_tilde)

    predictions = []

    # Resolve the device from the scorer; scorers without parameters
    # (like ForwardScorer) fall back to CPU.
    try:
        device = next(scorer.parameters()).device
    except StopIteration:
        device = torch.device('cpu')
    if not isinstance(tau_tilde, torch.Tensor):
        tau_tilde = torch.tensor(tau_tilde, dtype=torch.float32, device=device)
    else:
        tau_tilde = tau_tilde.to(device)
    if tau_tilde.ndim == 0:
        tau_tilde = tau_tilde.unsqueeze(0)

    # ndim >= 1 is guaranteed above. Indexing keeps tau_tilde in the autograd
    # graph. Hoisted because the value is identical for every example.
    tau_tilde_val = tau_tilde[0]

    for i, (x, n) in enumerate(zip(X, noise)):
        if debugger:
            debugger.prediction_example(i+1, len(X))

        # Extract data and ensure all tensors are on the correct device
        features = torch.as_tensor(x['features'], device=device)  # [n_nodes, n_features]
        adj = torch.as_tensor(x['adj'], device=device)  # [n_nodes, n_nodes]
        ancestors = torch.as_tensor(x['ancestors'], device=device)  # [n_nodes, n_nodes]

        # Score nodes. Cast to float32 exactly as compute_nonconformity_score
        # does, so test-time risks are computed the same way as during
        # calibration.
        scores = scorer.forward(features).to(device).float()  # [n_nodes]

        risks = compute_risk(scores, adj, C, beta_mix, n, debugger)  # [n_nodes]

        # Full grid from min_risk - margin to max_risk + margin (no hard cutoff).
        tau_grid = build_tau_grid(risks, margin, debugger)  # [T]

        # Evaluate every threshold at once (vectorised over the grid).
        p_keep_all = soft_keep(risks, tau_grid, temp, None)  # [n_nodes, T]
        log_coherent_all = ancestor_coherence(p_keep_all, ancestors, gamma, eps, None)  # [n_nodes, T]
        coherent_probs_all = log_coherent_all.exp()  # [n_nodes, T]

        # log w_tau = beta * tau + log sigmoid((tau_tilde - tau) / cutoff_temp)
        log_pref = beta * tau_grid
        log_gate = nn.functional.logsigmoid((tau_tilde_val - tau_grid) / cutoff_temp)
        log_w = log_pref + log_gate  # [T]

        # log(sum_tau w_tau + eps), computed stably: logaddexp folds the eps
        # term into the log-sum-exp, so weights == w / (sum(w) + eps) exactly.
        log_norm = torch.logaddexp(
            torch.logsumexp(log_w, dim=0),
            log_w.new_tensor(eps).log(),
        )
        weights = torch.exp(log_w - log_norm)  # [T]

        # Weighted average of the coherent keep probabilities over thresholds.
        final_probs = (weights.unsqueeze(0) * coherent_probs_all).sum(dim=1)  # [n_nodes]

        if debugger:
            # For debugging: report keep probs at the highest-weighted threshold
            best_threshold_idx = weights.argmax()
            debugger.prediction_results(risks, p_keep_all[:, best_threshold_idx].unsqueeze(-1), final_probs)

        predictions.append(final_probs)

    if debugger:
        debugger.prediction_end()

    return predictions


# -------------------------------
# Loss Functions
# -------------------------------

def compute_soft_retention_loss(
    soft_probs: List[torch.Tensor],
    labels: List[torch.Tensor],
    reduction: str = 'mean',
    debugger: FactualityDebugger | None = None
) -> torch.Tensor:
    """
    Compute retention loss: maximize true claims retained.

    Args:
        soft_probs: list of [n] probability tensors for each example
        labels: list of [n] label tensors for each example
        reduction: 'mean', 'sum', or 'none'
        debugger: optional debugger instance

    Returns:
        loss: negative retention (minimize = maximize retention)
    """
    if debugger:
        debugger.retention_loss_start(len(soft_probs), reduction)

    total_retention = 0.0

    for i, (probs, y) in enumerate(zip(soft_probs, labels)):
        y = y.to(probs.device).float()

        # Sum retention of true claims (y=1)
        retention = (y * probs).sum()
        total_retention = total_retention + retention

        if debugger:
            debugger.retention_loss_example(i, retention.item())

    # Return NEGATIVE (so optimizer minimizes = maximizes retention)
    if reduction == 'mean':
        loss = -total_retention / len(soft_probs)
    elif reduction == 'sum':
        loss = -total_retention
    else:
        loss = -total_retention

    if debugger:
        debugger.retention_loss_end(total_retention.item(), loss.item())

    return loss
