"""
Router debugging utilities.

This module contains all debug and diagnostic functions for router training,
separated from the main training logic for better maintainability.
"""

from typing import Dict, List, Optional, Any, Tuple
import torch
import torch.nn.functional as F
import numpy as np

from ..model_selection_carve import ModelRegistry, CandidateSetBuilder, RouterModel


def debug_router_supervision(
    router_model: RouterModel,
    logits: torch.Tensor,
    candidate_indices: torch.Tensor,
    gold_model_names: List[str],
    gold_indices: List[int],
    domains: List[str],
    registry: ModelRegistry,
    candidate_builder: Optional[CandidateSetBuilder] = None,
    hard_negative_cache: Optional[Dict] = None,
    targets: Optional[torch.Tensor] = None,
    loss: Optional[torch.Tensor] = None,
    prompt_embeddings: Optional[torch.Tensor] = None,
    global_step: int = 0,
    micro_idx: int = 0,
    debug_enabled: bool = False,
    debug_every: int = 100,
    debug_first_steps: int = 50,
    debug_strict: bool = False,
) -> None:
    """
    Comprehensive debug checks for router supervision end-to-end.
    
    What to look for:
    - Candidate/target alignment: gold model should be at candidates[0] for all examples
    - Loss correctness: CE loss computed correctly, logits match expectations
    - Gradient flow: router parameters receive gradients, prompt embeddings not detached
    - Provenance accounting: sem/hard/far percentages match actual counts
    
    Args:
        router_model: Router model instance
        logits: Router logits [B, K]
        candidate_indices: Candidate model indices [B, K]
        gold_model_names: List of gold model names [B]
        gold_indices: List of gold model indices [B] (should all be 0)
        domains: List of domains [B]
        registry: ModelRegistry instance
        candidate_builder: Optional CandidateSetBuilder for composition stats
        hard_negative_cache: Optional hard negative cache
        targets: Optional target indices [B] (should be all 0)
        loss: Optional computed loss tensor
        prompt_embeddings: Optional prompt embeddings [B, D] for gradient check
        global_step: Current global step
        micro_idx: Microbatch index (for gradient accumulation)
        debug_enabled: Whether debug is enabled
        debug_every: Run debug every N steps; a value <= 0 disables the
            periodic trigger (only the first ``debug_first_steps`` steps run)
        debug_first_steps: Run debug for first N steps
        debug_strict: If True, raise AssertionError on mismatches
    """
    if not debug_enabled:
        return
    
    # Run on every one of the first `debug_first_steps` steps, then every
    # `debug_every` steps. Guarding `debug_every > 0` avoids a
    # ZeroDivisionError from `global_step % 0`.
    periodic = debug_every > 0 and global_step % debug_every == 0
    if not (global_step < debug_first_steps or periodic):
        return
    
    batch_size, K = logits.shape
    device = logits.device
    
    print(f"\n{'='*80}")
    print(f"[RouterSupervisionDebug] Step {global_step}, Microbatch {micro_idx}")
    print(f"{'='*80}")
    print(f"  Batch size: {batch_size}, K: {K}")
    print(f"  Logits shape: {logits.shape}, dtype: {logits.dtype}, device: {logits.device}")
    
    # The checks below only read tensors, so they should not extend the
    # autograd graph of the training step. `no_grad` does not touch existing
    # `.grad` fields or `requires_grad` flags, which the gradient check reads.
    with torch.no_grad():
        # Candidate/target alignment checks
        _check_candidate_target_alignment(
            logits, candidate_indices, gold_model_names, gold_indices,
            domains, registry, batch_size, K, debug_strict
        )
        
        # Router loss correctness checks
        _check_loss_correctness(logits, targets, loss, batch_size, K, device)
        
        # Gradient flow checks
        _check_gradient_flow(router_model, prompt_embeddings)
        
        # Provenance accounting sanity
        _check_provenance_accounting(
            candidate_indices, gold_indices, domains, candidate_builder,
            hard_negative_cache, batch_size, K, registry
        )
    
    print(f"\n{'='*80}\n")


def _check_candidate_target_alignment(
    logits: torch.Tensor,
    candidate_indices: torch.Tensor,
    gold_model_names: List[str],
    gold_indices: List[int],
    domains: List[str],
    registry: ModelRegistry,
    batch_size: int,
    K: int,
    debug_strict: bool,
) -> None:
    """Print per-example checks that the gold model sits at candidates[0].

    For up to the first 3 examples, this prints:
    - the gold model and candidates[0], resolved through
      ``registry.model2idx`` and ``registry.idx2model``;
    - where the gold model appears in the candidate list;
    - the top-5 logits with their candidate models;
    - the rank of candidate position 0 when the logits are sorted in
      descending order.

    Args:
        logits: Router logits; row ``i`` is indexed by candidate position.
        candidate_indices: Registry indices of each example's candidates.
        gold_model_names: Gold model name per example.
        gold_indices: Not read by this check; kept so the signature matches
            the caller.
        domains: Domain label per example ("unknown" when missing).
        registry: Object exposing ``model2idx`` and ``idx2model`` mappings.
        batch_size: Number of rows in ``logits``.
        K: Number of candidates per example.
        debug_strict: If True, raise instead of only printing on a mismatch.

    Raises:
        AssertionError: If ``debug_strict`` is set and candidates[0] is not
            the gold model, or the gold model is missing from (or not first
            in) the candidate list.
    """
    print(f"\n  [A] Candidate/Target Alignment Checks:")
    # Also bound by the number of gold labels so a short list cannot raise
    # IndexError in the middle of the report.
    num_examples_to_check = min(3, batch_size, len(gold_model_names))
    
    for i in range(num_examples_to_check):
        gold_model = gold_model_names[i]
        candidates_i = candidate_indices[i].cpu().tolist()
        # Detached: topk/argsort below are read-only diagnostics.
        logits_i = logits[i].detach()
        domain = domains[i] if i < len(domains) else "unknown"
        
        candidate_models = [registry.idx2model.get(c, f"unknown_idx_{c}") for c in candidates_i]
        candidate_at_0 = candidate_models[0]
        gold_model_idx_from_registry = registry.model2idx.get(gold_model, None)
        candidate_idx_at_0 = candidates_i[0]
        
        print(f"\n    Example {i} (domain: {domain}):")
        print(f"      Gold model: '{gold_model}' (idx={gold_model_idx_from_registry})")
        print(f"      candidates[0]: '{candidate_at_0}' (idx={candidate_idx_at_0})")
        
        if candidate_idx_at_0 != gold_model_idx_from_registry:
            error_msg = (
                f"      ❌ MISMATCH: candidates[0] != gold_model!\n"
                f"         candidates[0] = {candidate_idx_at_0} ('{candidate_at_0}')\n"
                f"         gold_model = {gold_model_idx_from_registry} ('{gold_model}')"
            )
            print(error_msg)
            if debug_strict:
                raise AssertionError(error_msg)
        else:
            print(f"      ✓ candidates[0] == gold_model")
        
        # Position of the gold model in the candidate list (None if absent,
        # including when the gold name is not in the registry).
        try:
            gold_index_in_candidates = candidates_i.index(gold_model_idx_from_registry)
        except ValueError:
            gold_index_in_candidates = None
        
        if gold_index_in_candidates is None:
            error_msg = f"      ❌ Gold model not found in candidates!"
            print(error_msg)
            if debug_strict:
                raise AssertionError(error_msg)
        elif gold_index_in_candidates != 0:
            error_msg = (
                f"      ❌ Gold model at index {gold_index_in_candidates}, expected 0!\n"
                f"         Full candidates: {candidates_i[:10]}..."
            )
            print(error_msg)
            if debug_strict:
                raise AssertionError(error_msg)
        else:
            print(f"      ✓ Gold model at index 0")
        
        # Top-5 predictions; position 0 is the expected target.
        top5_logits, top5_indices = torch.topk(logits_i, k=min(5, K))
        print(f"      Top-5 predicted indices: {top5_indices.cpu().tolist()}")
        print(f"      Top-5 candidate models:")
        for rank, (logit_val, idx) in enumerate(zip(top5_logits, top5_indices), 1):
            model_name = candidate_models[idx.item()]
            is_correct = "✓" if idx.item() == 0 else "✗"
            print(f"        {rank}. {model_name} (logit={logit_val.item():.4f}) {is_correct}")
        
        # 1-based rank of candidate position 0 when logits are sorted descending.
        sorted_indices = torch.argsort(logits_i, descending=True)
        gold_rank = (sorted_indices == 0).nonzero(as_tuple=True)[0].item() + 1
        print(f"      Gold rank (via logits): {gold_rank}/{K}")


def _check_loss_correctness(
    logits: torch.Tensor,
    targets: Optional[torch.Tensor],
    loss: Optional[torch.Tensor],
    batch_size: int,
    K: int,
    device: torch.device,
) -> None:
    """Cross-check the router CE loss and print logit/probability statistics.

    The report has three parts:
    - Per-example cross-entropy is computed two ways, with ``F.cross_entropy``
      and with a manual ``log_softmax`` gather. Their max absolute difference
      is reported, with a warning above 1e-6.
    - If ``loss`` is given, it is compared against the mean per-example CE,
      with a warning above 1e-5.
    - Statistics are printed for the logit and probability at position 0
      versus positions 1..K-1, plus the fraction of rows where position 0
      beats every other candidate. The negative-side statistics are reported
      as "n/a" when there is only one candidate.

    Args:
        logits: Router logits ``[B, K]``.
        targets: Target positions ``[B]``; defaults to all zeros.
        loss: Loss from the training step. A scalar is compared directly; a
            tensor with several elements is compared by its mean.
        batch_size: Number of rows in ``logits``.
        K: Number of candidates per example.
        device: Device on which the default targets are created.
    """
    print(f"\n  [B] Router Loss Correctness Checks:")
    
    if targets is None:
        targets = torch.zeros(batch_size, dtype=torch.long, device=device)
    
    # Pure diagnostics: do not extend the autograd graph of the training logits.
    with torch.no_grad():
        # Check 1: Explicitly compute CE loss two ways
        ce_loss_method1 = F.cross_entropy(logits, targets, reduction="none")
        log_softmax = F.log_softmax(logits, dim=-1)
        ce_loss_method2 = -log_softmax[torch.arange(batch_size, device=device), targets]
        
        max_diff = (ce_loss_method1 - ce_loss_method2).abs().max().item()
        print(f"    CE loss computed two ways:")
        print(f"      Method 1 (F.cross_entropy): mean={ce_loss_method1.mean().item():.6f}, std={ce_loss_method1.std().item():.6f}")
        print(f"      Method 2 (manual log-softmax): mean={ce_loss_method2.mean().item():.6f}, std={ce_loss_method2.std().item():.6f}")
        print(f"      Max abs diff: {max_diff:.2e}")
        
        if max_diff > 1e-6:
            print(f"      ⚠️ WARNING: Methods differ by > 1e-6!")
            print(f"      Sample values (first 3):")
            for i in range(min(3, batch_size)):
                print(f"        Example {i}: method1={ce_loss_method1[i].item():.6f}, method2={ce_loss_method2[i].item():.6f}")
        else:
            print(f"      ✓ Methods match within tolerance")
        
        if loss is not None:
            expected_loss_val = ce_loss_method1.mean().item()
            # `.mean()` leaves a scalar unchanged and lets a per-example loss
            # be compared too (`.item()` alone raises on multi-element tensors).
            actual_loss_val = loss.float().mean().item()
            loss_diff = abs(actual_loss_val - expected_loss_val)
            print(f"    Computed loss: {actual_loss_val:.6f}")
            print(f"    Expected loss (mean of per-example): {expected_loss_val:.6f}")
            print(f"    Diff: {loss_diff:.2e}")
            if loss_diff > 1e-5:
                print(f"      ⚠️ WARNING: Loss mismatch!")
        
        # Check 2: Summary stats. With a single candidate there are no
        # negatives, and a max over an empty dimension raises RuntimeError.
        has_negatives = logits.shape[-1] > 1
        logits_pos = logits[:, 0]
        logits_neg = logits[:, 1:]
        
        print(f"\n    Logits summary stats:")
        print(f"      Positive (idx 0): mean={logits_pos.mean().item():.4f}, std={logits_pos.std().item():.4f}")
        if has_negatives:
            print(f"      Negative (idx 1:): mean={logits_neg.mean().item():.4f}, std={logits_neg.std().item():.4f}")
        else:
            print(f"      Negative (idx 1:): n/a (K=1)")
        print(f"      All logits: mean={logits.mean().item():.4f}, std={logits.std().item():.4f}")
        
        probs = F.softmax(logits, dim=-1)
        prob_pos = probs[:, 0]
        prob_neg = probs[:, 1:]
        
        print(f"    Probabilities summary stats:")
        print(f"      Positive (idx 0): mean={prob_pos.mean().item():.6f}")
        if has_negatives:
            print(f"      Negative (idx 1:): mean={prob_neg.mean().item():.6f}")
        else:
            print(f"      Negative (idx 1:): n/a (K=1)")
        
        if has_negatives:
            max_neg_logits = logits_neg.max(dim=1)[0]
            pos_better = (logits_pos > max_neg_logits).float().mean().item()
            print(f"    Fraction where logits_pos > max(logits_neg): {pos_better:.2%}")
        else:
            print(f"    Fraction where logits_pos > max(logits_neg): n/a (K=1)")


def _check_gradient_flow(
    router_model: RouterModel,
    prompt_embeddings: Optional[torch.Tensor],
) -> None:
    """Report which router parameters are trainable and whether they hold gradients.

    The report covers:
    - the number of parameters with ``requires_grad=True``, plus the names
      and shapes of the first five;
    - the gradient norm of each of the first five such parameters whose
      ``.grad`` is set;
    - min/max/mean over all of those norms, and whether every norm is
      finite;
    - if ``prompt_embeddings`` is given, whether it requires grad.

    Results are reported through ``print`` only.
    """
    print(f"\n  [C] Gradient Flow Checks:")
    
    trainable = [
        (pname, p) for pname, p in router_model.named_parameters() if p.requires_grad
    ]
    
    print(f"    Router parameters with requires_grad=True: {len(trainable)}")
    if len(trainable) > 0:
        print(f"    First 5 parameter names:")
        for pname, p in trainable[:5]:
            print(f"      {pname}: shape={p.shape}")
    
    # Norms of every populated gradient; only the first five are printed.
    norms = []
    for pname, p in trainable:
        if p.grad is None:
            continue
        norm_val = p.grad.norm().item()
        norms.append(norm_val)
        if len(norms) <= 5:
            print(f"      {pname}: grad_norm={norm_val:.6f}")
    
    # An empty `norms` means no trainable parameter has a gradient yet.
    if not norms:
        print(f"    ⚠️ WARNING: No router parameters have gradients yet (may be before backward)")
    else:
        print(f"    ✓ At least one router parameter has gradients")
        avg_norm = sum(norms) / len(norms)
        print(f"    Grad norm stats: min={min(norms):.6f}, max={max(norms):.6f}, mean={avg_norm:.6f}")
        if all(np.isfinite(v) for v in norms):
            print(f"    ✓ All gradients are finite")
        else:
            print(f"    ❌ WARNING: Some gradients are NaN/Inf!")
    
    if prompt_embeddings is None:
        return
    
    needs_grad = prompt_embeddings.requires_grad
    print(f"    Prompt embeddings requires_grad: {needs_grad}")
    if needs_grad:
        print(f"      ✓ Prompt embeddings require gradients")
    else:
        print(f"      ❌ WARNING: Prompt embeddings are detached! This will block gradient flow.")


def _check_provenance_accounting(
    candidate_indices: torch.Tensor,
    gold_indices: List[int],
    domains: List[str],
    candidate_builder: Optional[CandidateSetBuilder],
    hard_negative_cache: Optional[Dict],
    batch_size: int,
    K: int,
    registry: ModelRegistry,
) -> None:
    """Report candidate-set composition by provenance and sanity-check the totals.

    For each example, ``candidate_builder.get_composition_stats`` is called
    and its "positive", "hard", "semantic", "far" and "other" counts are
    summed across the batch. The counts are printed with percentages of the
    total. The summed count is then compared with the number of candidate
    slots examined, and a warning is printed on mismatch.

    NOTE(review): ``y_idx`` receives ``gold_indices[i]``, which the caller
    documents as "should all be 0". Confirm whether ``get_composition_stats``
    expects a candidate position or a registry model index.

    Args:
        candidate_indices: Candidate model indices per example.
        gold_indices: Gold index per example, forwarded as ``y_idx``.
        domains: Domain label per example.
        candidate_builder: Provides ``get_composition_stats``; when None the
            check is skipped.
        hard_negative_cache: Forwarded to ``get_composition_stats``.
        batch_size: Number of examples in the batch.
        K: Number of candidates per example (slots are counted from
            ``candidate_indices`` directly).
        registry: Not read by this check; kept for signature compatibility.
    """
    print(f"\n  [E] Provenance Accounting Sanity:")
    
    # `is not None` rather than truthiness, so a builder that defines
    # __len__/__bool__ cannot silently disable the check.
    if candidate_builder is None or not gold_indices or not domains:
        print(f"    (skipped: requires candidate_builder, gold_indices and domains)")
        return
    
    # (stats key, display label) in reporting order.
    categories = (
        ("positive", "Positive"),
        ("hard", "Hard"),
        ("semantic", "Semantic"),
        ("far", "Far"),
        ("other", "Other"),
    )
    totals = {key: 0 for key, _ in categories}
    num_slots = 0
    
    num_examples = min(batch_size, len(gold_indices), len(domains))
    for i in range(num_examples):
        candidates = candidate_indices[i].cpu().tolist()
        num_slots += len(candidates)
        stats = candidate_builder.get_composition_stats(
            candidates=candidates,
            y_idx=gold_indices[i],
            domain=domains[i],
            hard_negative_cache=hard_negative_cache,
        )
        for key in totals:
            totals[key] += stats.get(key, 0)
    
    total = sum(totals.values())
    if total <= 0:
        print(f"    (no composition counts returned for {num_examples} examples)")
        return
    
    print(f"    Composition counts (across batch):")
    for key, label in categories:
        count = totals[key]
        print(f"      {label}: {count} ({count/total*100:.1f}%)")
    print(f"      Total: {total}")
    
    # The "other" bucket suggests every candidate slot falls into some
    # category, so the counts should add up to the number of slots examined.
    if total != num_slots:
        print(f"      ⚠️ WARNING: Category counts ({total}) != candidate slots examined ({num_slots})")
    else:
        print(f"      ✓ Category counts cover all {num_slots} candidate slots")

