"""
Debugger for Differentiable Coherent Factuality

Centralized debugger classes that handle all debug output for the
differentiable conformal factuality pipeline, keeping the main code clean.
"""

import torch
from typing import Dict, Any, List


class FactualityDebugger:
    """
    Centralized debugger for differentiable coherent factuality operations.

    Usage:
        debugger = FactualityDebugger(enabled=True)
        debugger.soft_keep(risk, tau_list, margin, result)
    """

    def __init__(self, enabled: bool = False):
        """
        Initialize the debugger.

        Args:
            enabled: if True, debug output will be printed
        """
        self.enabled = enabled

    def _print_section(self, title: str, width: int = 60):
        """Print a section header."""
        if self.enabled:
            print("\n" + "="*width)
            print(f"=== {title} ===")
            print("="*width)

    def _print_subsection(self, title: str):
        """Print a subsection header."""
        if self.enabled:
            print(f"\n--- {title} ---")

    def soft_keep(
        self,
        risk: torch.Tensor,
        tau_list: torch.Tensor,
        temp: float,
        margin: torch.Tensor,
        result: torch.Tensor
    ):
        """Debug output for soft_keep operation."""
        if not self.enabled:
            return

        self._print_section("soft_keep")
        print(f"Temperature: {temp}")
        print(f"Risk shape: {risk.shape}, Tau list shape: {tau_list.shape}")
        print(f"Risk values: {risk}")
        print(f"Tau list: {tau_list}")
        print(f"Margin (tau - risk) shape: {margin.shape}")
        print(f"Margin min/max: {margin.min().item():.4f} / {margin.max().item():.4f}")
        print(f"Result (keep probs) min/max: {result.min().item():.4f} / {result.max().item():.4f}")

    def ancestor_coherence(
        self,
        scores: torch.Tensor,
        ancestors: torch.Tensor,
        gamma: float,
        eps: float,
        weights: torch.Tensor,
        log_scores: torch.Tensor,
        log_coherent: torch.Tensor
    ):
        """Debug output for ancestor_coherence operation."""
        if not self.enabled:
            return

        self._print_section("ancestor_coherence")
        print(f"Gamma (ancestor decay): {gamma}")
        print(f"Eps: {eps}")
        print(f"Input scores shape: {scores.shape}")
        print(f"Ancestors matrix shape: {ancestors.shape}")
        print(f"Ancestor matrix:\n{ancestors.int()}")
        print(f"Weight matrix (after normalization):\n{weights}")
        print(f"Log scores min/max: {log_scores.min().item():.4f} / {log_scores.max().item():.4f}")
        print(f"Log coherent min/max: {log_coherent.min().item():.4f} / {log_coherent.max().item():.4f}")

    def validity_negatives(
        self,
        log_coherent: torch.Tensor,
        labels: torch.Tensor,
        eps: float,
        n_neg: float,
        probs: torch.Tensor,
        log_terms: torch.Tensor,
        logQ: torch.Tensor
    ):
        """Debug output for size_invariant_validity_negatives operation."""
        if not self.enabled:
            return

        self._print_section("size_invariant_validity_negatives")
        print(f"Eps: {eps}")
        print(f"Labels shape: {labels.shape}")
        print(f"Labels: {labels}")
        print(f"Number of negative claims: {n_neg:.0f}")
        print(f"Log coherent shape: {log_coherent.shape}")
        print(f"Probs (exp of log_coherent) min/max: {probs.min().item():.4f} / {probs.max().item():.4f}")
        print(f"Log terms (log(1-p)) min/max: {log_terms.min().item():.4f} / {log_terms.max().item():.4f}")
        print(f"LogQ shape: {logQ.shape}")
        print(f"LogQ values: {logQ}")

    def violation_from_logQ(
        self,
        log_validity: torch.Tensor,
        eps: float,
        z_min: float,
        z_max: float,
        z_hat: torch.Tensor,
        V: torch.Tensor
    ):
        """Debug output for violation_from_logQ operation."""
        if not self.enabled:
            return

        self._print_section("violation_from_logQ")
        print(f"Eps: {eps}")
        print(f"Log validity (input): {log_validity}")
        print(f"z_min: {z_min:.4f}, z_max: {z_max:.4f}")
        print(f"Normalized z_hat: {z_hat}")
        print(f"Violation V (1 - z_hat): {V}")

    def soft_supremum(
        self,
        tau_list: torch.Tensor,
        V: torch.Tensor,
        beta: float,
        lambda_: float,
        eps: float,
        tau_hat: torch.Tensor,
        V_hat: torch.Tensor,
        s: torch.Tensor,
        pi: torch.Tensor,
        tau_tilde: torch.Tensor
    ):
        """Debug output for soft_supremum_from_violation operation."""
        if not self.enabled:
            return

        self._print_section("soft_supremum_from_violation")
        print(f"Beta (softmax sharpness): {beta}")
        print(f"Lambda (violation penalty): {lambda_}")
        print(f"Eps: {eps}")
        print(f"Tau list: {tau_list}")
        print(f"Violation V: {V}")
        print(f"Normalized tau_hat: {tau_hat}")
        print(f"Normalized V_hat: {V_hat}")
        print(f"Score s = tau_hat - {lambda_:.2f}*V_hat: {s}")
        print(f"Softmax weights pi: {pi}")
        print(f"Weighted tau_tilde: {tau_tilde.item():.4f}")

    def compute_risk(
        self,
        scores: torch.Tensor,
        C: float,
        beta_mix: float,
        scalar_noise: float | None,
        initial_risk: torch.Tensor,
        risk_after_mix: torch.Tensor | None,
        final_risk: torch.Tensor,
        node_details: List[Dict[str, Any]] | None = None
    ):
        """Debug output for compute_risk operation."""
        if not self.enabled:
            return

        self._print_section("compute_risk")
        print(f"C: {C}")
        print(f"Beta_mix: {beta_mix}")
        print(f"Scalar_noise: {scalar_noise}")
        print(f"Input scores: {scores}")
        print(f"Initial risk (C - scores): {initial_risk}")

        if beta_mix > 0 and node_details:
            for detail in node_details:
                print(f"  Node {detail['node']}: descendants={detail['descendants']}, "
                      f"median={detail['median']:.4f}, "
                      f"old_risk={detail['old_risk']:.4f}, new_risk={detail['new_risk']:.4f}")
            print(f"Risk after beta_mix: {risk_after_mix}")

        if scalar_noise is not None:
            print(f"Risk after scalar_noise: {final_risk}")

        print(f"Final risk: {final_risk}")

    def build_tau_grid(
        self,
        risk: torch.Tensor,
        margin: float,
        tau_unique: torch.Tensor,
        tau_min: float,
        tau_max: float,
        tau_list: torch.Tensor
    ):
        """Debug output for build_tau_grid operation."""
        if not self.enabled:
            return

        self._print_section("build_tau_grid")
        print(f"Margin: {margin}")
        print(f"Risk values: {risk}")
        print(f"Risk min: {risk.min().item():.4f}, max: {risk.max().item():.4f}")
        print(f"Unique tau values: {tau_unique.tolist()}")
        print(f"Tau min (risk.min - margin): {tau_min:.4f}")
        print(f"Tau max (risk.max + margin): {tau_max:.4f}")
        print(f"Final tau grid (with sentinels): {tau_list.tolist()}")
        print(f"Tau grid size: {len(tau_list)}")

    def soft_quantile(
        self,
        values: torch.Tensor,
        q: float,
        regularization_strength: float,
        target_idx: float,
        weights: torch.Tensor,
        sorted_vals: torch.Tensor,
        result: torch.Tensor
    ):
        """Debug output for soft_quantile operation."""
        if not self.enabled:
            return

        self._print_section("soft_quantile")
        print(f"Quantile q: {q}")
        print(f"Regularization strength: {regularization_strength}")
        print(f"Values: {values}")
        print(f"Target index: {target_idx:.4f}")
        print(f"Weights: {weights}")
        print(f"Sorted values: {sorted_vals}")
        print(f"Result: {result.item():.4f}")

    def nonconformity_score_start(
        self,
        noise_val: float,
        C: float,
        beta_mix: float,
        margin: float,
        temp: float,
        beta: float,
        gamma: float,
        lambda_: float,
        eps_keep: float,
        eps_val: float,
        eps_violation: float
    ):
        """Debug output for start of nonconformity score computation."""
        if not self.enabled:
            return

        self._print_section("COMPUTING NONCONFORMITY SCORE")
        print(f"Hyperparameters:")
        print(f"  C={C}, beta_mix={beta_mix}, noise={noise_val}")
        print(f"  margin={margin}, temp={temp}, beta={beta}")
        print(f"  gamma={gamma}, lambda_={lambda_}")
        print(f"  eps_keep={eps_keep}, eps_val={eps_val}, eps_violation={eps_violation}")

    def nonconformity_score_summary(
        self,
        risk: torch.Tensor,
        tau_list: torch.Tensor,
        tau_risk: float,
        t_score: float,
        C: float
    ):
        """Debug output for nonconformity score summary."""
        if not self.enabled:
            return

        self._print_section("NONCONFORMITY SCORE SUMMARY")
        print(f"Risk values per claim:")
        for idx, val in enumerate(risk):
            print(f"  Claim {idx}: risk={val.item():.4f}")
        print(f"\nTau grid size: {len(tau_list)}")
        print(f"Soft supremum (risk τ): {tau_risk:.4f}")
        print(f"Soft boundary (score t): {t_score:.4f}")
        print(f"Nonconformity score: {tau_risk:.4f}")

    def prob_space(
        self,
        p_coherent: torch.Tensor,
        Q: torch.Tensor
    ):
        """Debug output for probability-space values."""
        if not self.enabled:
            return

        self._print_section("DEBUG (prob-space)")
        print(f"Coherent probs p_coherent [n,T]:\n{p_coherent}")
        print(f"Per-τ validity Q_τ [T]:\n{Q}")

    def calibration_start(
        self,
        n_examples: int,
        alpha: float,
        C: float,
        beta_mix: float,
        temp: float,
        beta: float,
        gamma: float,
        lambda_: float
    ):
        """Debug output for start of calibration."""
        if not self.enabled:
            return

        self._print_section("CALIBRATION START", width=80)
        print(f"Number of calibration examples: {n_examples}")
        print(f"Alpha (miscoverage): {alpha}")
        print(f"Hyperparameters: C={C}, beta_mix={beta_mix}, temp={temp}, beta={beta}")
        print(f"                gamma={gamma}, lambda_={lambda_}")

    def calibration_example(self, i: int, total: int):
        """Debug output for calibration example."""
        if not self.enabled:
            return
        self._print_subsection(f"Calibration example {i}/{total}")

    def calibration_quantile(
        self,
        taus: torch.Tensor,
        q: float,
        tau_tilde: torch.Tensor
    ):
        """Debug output for calibration quantile."""
        if not self.enabled:
            return

        print(f"\nNonconformity scores: {taus}")
        print(f"Quantile level q: {q:.4f}")
        print(f"\nCalibrated threshold tau_tilde: {tau_tilde.item():.4f}")

    def calibration_end(self):
        """Debug output for end of calibration."""
        if not self.enabled:
            return

        print("="*80)
        print("=== CALIBRATION END ===")
        print("="*80 + "\n")

    def prediction_start(
        self,
        n_examples: int,
        tau_tilde: torch.Tensor
    ):
        """Debug output for start of prediction."""
        if not self.enabled:
            return

        self._print_section("PREDICTION START")
        print(f"Number of test examples: {n_examples}")
        print(f"Threshold tau_tilde: {tau_tilde.item() if tau_tilde.ndim > 0 else tau_tilde:.4f}")

    def prediction_example(self, i: int, total: int):
        """Debug output for prediction example."""
        if not self.enabled:
            return
        self._print_subsection(f"Test example {i}/{total}")

    def prediction_results(
        self,
        risks: torch.Tensor,
        p_keep: torch.Tensor,
        coherent_probs: torch.Tensor
    ):
        """Debug output for prediction results."""
        if not self.enabled:
            return

        print(f"Risks: {risks}")
        print(f"Soft keep probs: {p_keep.squeeze(-1)}")
        print(f"Coherent probs: {coherent_probs}")

    def prediction_end(self):
        """Debug output for end of prediction."""
        if not self.enabled:
            return

        print("="*60)
        print("=== PREDICTION END ===")
        print("="*60 + "\n")

    def retention_loss_start(
        self,
        n_examples: int,
        reduction: str
    ):
        """Debug output for start of retention loss computation."""
        if not self.enabled:
            return

        self._print_section("compute_soft_retention_loss")
        print(f"Number of examples: {n_examples}")
        print(f"Reduction: {reduction}")

    def retention_loss_example(
        self,
        i: int,
        retention: float
    ):
        """Debug output for retention loss example."""
        if not self.enabled:
            return

        print(f"Example {i}: retention={retention:.4f}")

    def retention_loss_end(
        self,
        total_retention: float,
        loss: float
    ):
        """Debug output for end of retention loss computation."""
        if not self.enabled:
            return

        print(f"Total retention: {total_retention:.4f}")
        print(f"Loss (negative retention): {loss:.4f}")


class TrainingDebugger:
    """
    Console reporter for the training pipeline.

    Covers training epochs, validation runs, and baseline evaluation. Every
    method writes to stdout and does nothing when ``enabled`` is False.
    """

    def __init__(self, enabled: bool = False):
        """
        Create the training debugger.

        Args:
            enabled: when True, the reporting methods print their output
        """
        self.enabled = enabled

    def training_start(self, n_epochs: int, alpha: float, lr: float, hyperparams: Dict[str, Any]):
        """Print the training banner followed by the run settings."""
        if self.enabled:
            rule = "="*70
            print("\n" + rule)
            print("TRAINING START")
            print(rule)
            print(f"Epochs: {n_epochs}, α: {alpha}, learning rate: {lr}")
            print(f"Hyperparameters: {hyperparams}")

    def epoch_progress(
        self,
        epoch: int,
        n_epochs: int,
        train_loss: float,
        val_loss: float,
        val_coverage: float,
        val_marginal_coverage: float,
        val_claims: float
    ):
        """Print one summary line with losses, coverage and claims for an epoch."""
        if self.enabled:
            line = (
                f"Epoch {epoch:3d}/{n_epochs} | "
                f"Loss: {train_loss:.4f}/{val_loss:.4f} | "
                f"Cov: {val_coverage:.3f} (marg: {val_marginal_coverage:.3f}) | "
                f"Claims: {val_claims:.2f}"
            )
            print(line)

    def early_stopping(self, epoch: int, patience: int):
        """Print the epoch at which training stopped early and the patience used."""
        if self.enabled:
            print(f"\nEarly stopping at epoch {epoch} (patience: {patience})")

    def hard_validation_start(self, n_trials: int, beta: float, target_coverage: float):
        """Print the hard-validation banner followed by its settings."""
        if self.enabled:
            rule = "="*70
            print("\n" + rule)
            print("HARD CONFORMAL VALIDATION")
            print(rule)
            print(f"Trials: {n_trials}, β: {beta}, target coverage: {target_coverage:.3f}")

    def hard_validation_results(
        self,
        coverage: float,
        claims_retained: float,
        precision: float,
        target_coverage: float
    ):
        """Print hard-validation metrics and whether coverage reached the target."""
        if self.enabled:
            print(f"\nResults:")
            print(f"  Coverage: {coverage:.3f} (target: {target_coverage:.3f})")
            print(f"  Claims retained: {claims_retained:.3f}")
            print(f"  Precision: {precision:.3f}")

            # The verdict compares the raw values, not the rounded ones shown.
            target_met = coverage >= target_coverage
            verdict = (
                f"  ✓ Coverage guarantee MET ({coverage:.3f} ≥ {target_coverage:.3f})"
                if target_met
                else f"  ✗ Coverage guarantee NOT MET ({coverage:.3f} < {target_coverage:.3f})"
            )
            print(verdict)

    def soft_validation_results(
        self,
        coverage: float,
        claims_retained: float,
        loss: float
    ):
        """Print the soft-validation banner and its reference metrics."""
        if self.enabled:
            rule = "="*70
            print("\n" + rule)
            print("SOFT VALIDATION (Reference)")
            print(rule)
            print(f"  Coverage: {coverage:.3f}")
            print(f"  Claims retained: {claims_retained:.3f}")
            print(f"  Loss: {loss:.4f}")

    def training_complete(self, best_epoch: int, total_epochs: int):
        """Print the completion banner and which epoch produced the best model."""
        if self.enabled:
            rule = "="*70
            print("\n" + rule)
            print("TRAINING COMPLETE")
            print(rule)
            print(f"Best model from epoch {best_epoch}/{total_epochs}")

    def baseline_evaluation_start(self, n_trials: int, beta: float, alpha: float):
        """Print the settings used for the baseline evaluation."""
        if self.enabled:
            print(f"\nEvaluating baseline: {n_trials} trials, β={beta}, α={alpha}")

    def baseline_evaluation_results(self, coverage: float, claims: float):
        """Print the baseline coverage and claims on one line."""
        if self.enabled:
            print(f"Baseline results: coverage={coverage:.3f}, claims={claims:.3f}")

class EvaluationDebugger:
    """
    Console reporter for dataset-wide evaluation and comparison runs.

    Prints progress lines and per-alpha summaries for the calibration,
    prediction and training comparisons. Every method writes to stdout and
    does nothing when ``enabled`` is False.

    Usage:
        debugger = EvaluationDebugger(enabled=True)
        debugger.calibration_dataset_start(n_examples=50, alphas=[0.05, 0.1])
        for idx in range(n_examples):
            debugger.calibration_example_progress(idx, n_examples, ...)
    """

    def __init__(self, enabled: bool = False):
        """
        Create the evaluation debugger.

        Args:
            enabled: when True, the reporting methods print their output
        """
        self.enabled = enabled

    @staticmethod
    def _is_report_index(idx: int, n_examples: int) -> bool:
        """Return True for every tenth index (0, 10, 20, ...) and for the last index."""
        return idx % 10 == 0 or idx == n_examples - 1

    def _print_header(self, title: str, width: int = 80):
        """Print a main header: the title centered between two '=' rules."""
        if not self.enabled:
            return
        rule = "="*width
        print("\n" + rule)
        print(title.center(width))
        print(rule)

    def _print_section(self, title: str, width: int = 80):
        """Print a section header: the title between two '-' rules."""
        if not self.enabled:
            return
        rule = "-"*width
        print("\n" + rule)
        print(title)
        print(rule)

    # ==================================================================
    # CALIBRATION COMPARISON
    # ==================================================================

    def calibration_dataset_start(self, n_examples: int, alphas: list):
        """Print the header and settings for the calibration comparison."""
        if self.enabled:
            self._print_header("CALIBRATION COMPARISON: HARD vs SOFT")
            print(f"Dataset size: {n_examples} examples")
            print(f"Alpha levels: {alphas}")
            print(f"Testing: How close are soft and hard tau thresholds?")

    def calibration_alpha_start(self, alpha: float, n_examples: int):
        """Print the section header for one alpha level of calibration."""
        if self.enabled:
            self._print_section(f"Alpha = {alpha} ({n_examples} examples)")

    def calibration_example_progress(self, idx: int, n_examples: int,
                                     hard_tau: float, soft_tau: float,
                                     rel_diff: float):
        """Print a progress line for selected examples (every tenth index and the last)."""
        if self.enabled and self._is_report_index(idx, n_examples):
            print(f"  [{idx+1}/{n_examples}] hard={hard_tau:.4f}, soft={soft_tau:.4f}, "
                  f"rel_error={rel_diff:.2%}")

    def calibration_alpha_summary(self, alpha: float, mean_error: float,
                                  std_error: float, median_error: float,
                                  n_examples: int):
        """Print mean, spread and median of the relative error for one alpha."""
        if self.enabled:
            print(f"\nSummary for α={alpha}:")
            print(f"  Mean relative error:   {mean_error:.2%} ± {std_error:.2%}")
            print(f"  Median relative error: {median_error:.2%}")
            print(f"  Samples:               {n_examples}")

    # ==================================================================
    # PREDICTION COMPARISON
    # ==================================================================

    def prediction_dataset_start(self, n_examples: int, alphas: list):
        """Print the header and settings for the prediction comparison."""
        if self.enabled:
            self._print_header("PREDICTION COMPARISON: HARD vs SOFT")
            print(f"Dataset size: {n_examples} examples")
            print(f"Alpha levels: {alphas}")
            print(f"Testing: How well do soft and hard prediction sets agree?")

    def prediction_alpha_start(self, alpha: float, n_examples: int):
        """Print the section header for one alpha level of prediction."""
        if self.enabled:
            self._print_section(f"Alpha = {alpha} ({n_examples} examples)")

    def prediction_example_progress(self, idx: int, n_examples: int,
                                    accuracy: float, exact_match: bool):
        """Print a progress line for selected examples (every tenth index and the last)."""
        if self.enabled and self._is_report_index(idx, n_examples):
            if exact_match:
                match_str = "✓"
            else:
                match_str = "✗"
            print(f"  [{idx+1}/{n_examples}] accuracy={accuracy:.3f} exact_match={match_str}")

    def prediction_alpha_summary(self, alpha: float, mean_accuracy: float,
                                 mean_precision: float, mean_recall: float,
                                 exact_match_rate: float, n_examples: int):
        """Print the averaged agreement metrics for one alpha."""
        if self.enabled:
            print(f"\nSummary for α={alpha}:")
            print(f"  Mean accuracy:    {mean_accuracy:.3f}")
            print(f"  Mean precision:   {mean_precision:.3f}")
            print(f"  Mean recall:      {mean_recall:.3f}")
            print(f"  Exact match rate: {exact_match_rate:.1%}")
            print(f"  Samples:          {n_examples}")

    # ==================================================================
    # TRAINING EVALUATION
    # ==================================================================

    def training_dataset_start(self, n_trials: int, alphas: list):
        """Print the header and settings for the training evaluation."""
        if self.enabled:
            self._print_header("TRAINING EVALUATION: LEARNED vs BASELINE")
            print(f"Trials per alpha: {n_trials}")
            print(f"Alpha levels: {alphas}")
            print(f"Testing: Does learned model improve claim retention?")

    def training_alpha_start(self, alpha: float, n_trials: int):
        """Print the section header for one alpha level of training evaluation."""
        if self.enabled:
            self._print_section(f"Alpha = {alpha} ({n_trials} trials)")

    def training_trial_progress(self, trial_idx: int, n_trials: int,
                               learned_coverage: float, learned_claims: float,
                               baseline_claims: float, improvement: float):
        """Print a progress line for a trial; unlike the example progress, no trial is skipped."""
        if self.enabled:
            print(f"  [{trial_idx+1}/{n_trials}] learned={learned_claims:.3f}, "
                  f"baseline={baseline_claims:.3f}, improvement={improvement:+.3f}, "
                  f"coverage={learned_coverage:.3f}")

    def training_alpha_summary(self, alpha: float, n_trials: int,
                              avg_learned_coverage: float,
                              avg_learned_claims: float,
                              avg_baseline_claims: float,
                              avg_improvement: float):
        """Print averaged learned-vs-baseline metrics for one alpha."""
        if not self.enabled:
            return

        # A non-positive baseline would make the ratio meaningless (or divide
        # by zero), so the percentage is reported as 0 in that case.
        if avg_baseline_claims > 0:
            improvement_pct = avg_improvement / avg_baseline_claims * 100
        else:
            improvement_pct = 0

        print(f"\nSummary for α={alpha} ({n_trials} trials):")
        print(f"  Learned coverage: {avg_learned_coverage:.3f}")
        print(f"  Learned claims:   {avg_learned_claims:.3f}")
        print(f"  Baseline claims:  {avg_baseline_claims:.3f}")
        print(f"  Improvement:      {avg_improvement:+.3f} ({improvement_pct:+.1f}%)")

    def training_evaluation_complete(self, n_alphas: int):
        """Print the completion notice for the training evaluation."""
        if self.enabled:
            print(f"\n✓ Training evaluation complete for {n_alphas} alpha levels")
            print("  Results saved to tests/plots/training/")

    # ==================================================================
    # OVERALL SUMMARY
    # ==================================================================

    def calibration_evaluation_complete(self, n_alphas: int):
        """Print the completion notice for the calibration evaluation."""
        if self.enabled:
            print(f"\n✓ Calibration evaluation complete for {n_alphas} alpha levels")
            print("  Results saved to tests/plots/calibration/")

    def prediction_evaluation_complete(self, n_alphas: int):
        """Print the completion notice for the prediction evaluation."""
        if self.enabled:
            print(f"\n✓ Prediction evaluation complete for {n_alphas} alpha levels")
            print("  Results saved to tests/plots/prediction/")
