"""
Runtime invariant checks for router training pipeline.

These checks are designed to be cheap and run only for the first N steps
to catch silent failures early without impacting training performance.
"""

from typing import List, Dict, Any, Optional, Set
import torch


class RouterDebugChecker:
    """
    Validates router training invariants during the first N global steps.
    
    Usage:
        checker = RouterDebugChecker(max_check_steps=20)
        
        # In training loop:
        if checker.should_check(global_step):
            checker.check_candidate_set_invariants(...)
            checker.check_semantic_batch_invariants(...)
    """
    
    def __init__(self, max_check_steps: int = 20):
        """
        Set up the checker's bookkeeping state.

        Args:
            max_check_steps: Checks are considered enabled only while the
                global step is strictly below this value.
        """
        # Keys of warnings that were already printed, so each is shown once.
        self.warnings_issued: Set[str] = set()
        # Counter initialised to zero; nothing in this class updates it.
        self.checks_run = 0
        self.max_check_steps = max_check_steps
    
    def should_check(self, global_step: int) -> bool:
        """
        Tell whether invariant checks are still enabled at ``global_step``.

        Args:
            global_step: Current training step

        Returns:
            True while ``global_step`` is strictly below ``max_check_steps``,
            False from that step onwards.
        """
        window_exhausted = global_step >= self.max_check_steps
        return not window_exhausted
    
    def _warn_once(self, key: str, message: str):
        """Print ``message`` the first time ``key`` is seen; stay silent afterwards."""
        if key in self.warnings_issued:
            return
        print(f"  ⚠️ [Router Debug] {message}")
        # Remember the key only after the message was actually emitted.
        self.warnings_issued.add(key)
    
    def check_semantic_batch_invariants(
        self,
        domains: List[str],
        domains_per_batch: int,
        global_step: int,
        allow_extra_domains: int = 1,
    ) -> Dict[str, Any]:
        """
        Check semantic batching invariants.

        Invariants:
        1. The batch contains at most ``domains_per_batch + allow_extra_domains``
           distinct normalized domains.
        2. Domain purity: when ``domains_per_batch == 1``, at least 80% of the
           examples must come from the dominant (most frequent) domain.

        Args:
            domains: List of domain strings from batch
            domains_per_batch: Expected number of domains per batch
            global_step: Current training step
            allow_extra_domains: Number of extra domains allowed (for replay mixing)

        Returns:
            Dict with check results. For an empty batch this is
            ``{"passed": False, "reason": "empty_domains"}``. Otherwise it holds
            ``passed``, ``num_unique_domains``, ``expected_domains``,
            ``dominant_pct`` and ``batch_size``, plus ``reason`` when a check
            failed (if both checks fail, the purity reason is the one reported).
        """
        if not domains:
            self._warn_once("empty_domains", f"Step {global_step}: Empty domains list in batch")
            return {"passed": False, "reason": "empty_domains"}

        # Imported lazily, after the empty-batch guard, as the original code did.
        from collections import Counter
        from ..model_selection_carve.model_registry import normalize_domain

        # Single pass over the batch: frequency of every normalized domain.
        domain_counts = Counter(normalize_domain(d) for d in domains)
        num_unique = len(domain_counts)

        # most_common() is deterministic: ties go to the domain seen first in
        # the batch, instead of depending on set iteration order.
        dominant_domain, dominant_count = domain_counts.most_common(1)[0]
        dominant_pct = (dominant_count / len(domains)) * 100

        results = {
            "passed": True,
            "num_unique_domains": num_unique,
            "expected_domains": domains_per_batch,
            "dominant_pct": dominant_pct,
            "batch_size": len(domains),
        }

        # Check 1: Number of unique domains
        expected_max = domains_per_batch + allow_extra_domains
        if num_unique > expected_max:
            self._warn_once(
                f"domain_count_step_{global_step}",
                f"Step {global_step}: Expected ≤{expected_max} domains, got {num_unique}. "
                f"Domains: {sorted(domain_counts)[:5]}..."
            )
            results["passed"] = False
            results["reason"] = "too_many_domains"

        # Check 2: Domain purity (at least 80% from dominant domain when domains_per_batch == 1)
        if domains_per_batch == 1 and dominant_pct < 80.0:
            self._warn_once(
                f"domain_purity_step_{global_step}",
                f"Step {global_step}: Domain purity low: {dominant_pct:.1f}% from dominant domain '{dominant_domain}'. "
                f"Expected ≥80% for domains_per_batch=1."
            )
            results["passed"] = False
            results["reason"] = "low_purity"

        return results
    
    def check_candidate_set_invariants(
        self,
        candidate_indices: torch.Tensor,
        y_indices: List[int],
        K_total: int,
        global_step: int,
        check_uniqueness: bool = True,
        check_positive_at_zero: bool = True,
        num_models: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Check candidate set invariants per example.

        Invariants:
        1. The second dimension of ``candidate_indices`` equals ``K_total``
        2. Positive model is at index 0 for all examples
        3. All candidates are unique per example (no duplicates)
        4. All candidate indices are valid: non-negative, and below
           ``num_models`` when that bound is supplied

        Args:
            candidate_indices: Tensor of candidate model indices [B, K]
            y_indices: List of positive model indices [B]
            K_total: Expected number of candidates per example
            global_step: Current training step
            check_uniqueness: Whether to check for duplicate candidates
            check_positive_at_zero: Whether to check positive is at index 0.
                The check is skipped when ``len(y_indices)`` differs from the
                batch size.
            num_models: Optional exclusive upper bound for candidate indices.
                When None (default) only non-negativity is checked.

        Returns:
            Dict with check results: always ``passed``, ``batch_size``, ``K``
            and ``K_expected``; on failure also ``reason`` (the last failing
            check wins) and per-check counters.
        """
        batch_size = candidate_indices.size(0)
        K = candidate_indices.size(1)

        results = {
            "passed": True,
            "batch_size": batch_size,
            "K": K,
            "K_expected": K_total,
        }

        # Check 1: Shape invariant
        if K != K_total:
            self._warn_once(
                f"candidate_shape_step_{global_step}",
                f"Step {global_step}: Expected K={K_total} candidates, got K={K}"
            )
            results["passed"] = False
            results["reason"] = "wrong_K"
            return results

        # Move the whole tensor to the host once instead of once per row.
        rows = candidate_indices.cpu().tolist()

        # Check 2: Positive at index 0 (needs at least one column to look at)
        if check_positive_at_zero and K > 0 and len(y_indices) == batch_size:
            mismatches = [
                f"example_{i}: got {row[0]}, expected {expected_pos}"
                for i, (row, expected_pos) in enumerate(zip(rows, y_indices))
                if row[0] != expected_pos
            ]

            if mismatches:
                self._warn_once(
                    f"positive_not_at_zero_step_{global_step}",
                    f"Step {global_step}: Positive model NOT at index 0 for {len(mismatches)} examples. "
                    f"First 3: {mismatches[:3]}"
                )
                results["passed"] = False
                results["reason"] = "positive_not_at_zero"
                results["num_mismatches"] = len(mismatches)

        # Check 3: Uniqueness per example
        if check_uniqueness:
            num_duplicates = 0
            for i, row in enumerate(rows):
                unique_count = len(set(row))
                if unique_count < K:
                    num_duplicates += 1
                    if num_duplicates == 1:  # Log first occurrence only
                        duplicates = [c for c in row if row.count(c) > 1]
                        self._warn_once(
                            f"duplicate_candidates_step_{global_step}",
                            f"Step {global_step}: Example {i} has {K - unique_count} duplicate candidates. "
                            f"Duplicates: {duplicates[:3]}..."
                        )

            if num_duplicates > 0:
                results["passed"] = False
                results["reason"] = "duplicate_candidates"
                results["num_examples_with_duplicates"] = num_duplicates

        # Check 4: Valid indices. min()/max() raise on an empty tensor, so an
        # empty batch is skipped here rather than crashing the debug check.
        if candidate_indices.numel() > 0:
            min_idx = candidate_indices.min().item()
            if min_idx < 0:
                self._warn_once(
                    f"negative_candidate_step_{global_step}",
                    f"Step {global_step}: Found negative candidate index: {min_idx}"
                )
                results["passed"] = False
                results["reason"] = "negative_index"

            if num_models is not None:
                max_idx = candidate_indices.max().item()
                if max_idx >= num_models:
                    self._warn_once(
                        f"candidate_out_of_range_step_{global_step}",
                        f"Step {global_step}: Found candidate index {max_idx} >= num_models={num_models}"
                    )
                    results["passed"] = False
                    results["reason"] = "index_out_of_range"

        return results
    
    def check_router_embeddings_in_optimizer(
        self,
        optimizer: torch.optim.Optimizer,
        router_model: torch.nn.Module,
        global_step: int,
    ) -> Dict[str, Any]:
        """
        Verify that every router parameter is registered with the optimizer.

        Parameters are matched by object identity (``id``). A parameter that is
        absent from all optimizer param groups is never updated, so any missing
        parameter makes the check fail.

        Args:
            optimizer: Training optimizer
            router_model: Router model
            global_step: Current training step

        Returns:
            Dict with ``passed``, ``num_router_params``, ``num_in_optimizer``
            and ``num_missing``; ``reason`` is added when the check fails.
        """
        router_ids = {id(param) for param in router_model.parameters()}

        # Identity of every parameter held by any of the optimizer's groups.
        optimizer_ids = {
            id(param)
            for group in optimizer.param_groups
            for param in group['params']
        }

        missing_ids = router_ids - optimizer_ids
        present_ids = router_ids & optimizer_ids

        total_router = len(router_ids)
        total_present = len(present_ids)

        results = {
            "passed": total_router == total_present,
            "num_router_params": total_router,
            "num_in_optimizer": total_present,
            "num_missing": len(missing_ids),
        }

        if results["passed"]:
            return results

        # Fixed warning key: reported once per checker, not once per step.
        self._warn_once(
            "router_params_missing",
            f"Step {global_step}: CRITICAL: {len(missing_ids)}/{total_router} router params "
            f"NOT in optimizer! Router will NOT be trained!"
        )
        results["reason"] = "params_not_in_optimizer"
        return results
    
    def check_prompt_mask_correctness(
        self,
        labels: torch.Tensor,
        attention_mask: torch.Tensor,
        prompt_mask: torch.Tensor,
        global_step: int,
    ) -> Dict[str, Any]:
        """
        Check that prompt mask extraction is correct.

        Invariants:
        1. prompt_mask == (labels == -100) & (attention_mask == 1)
        2. At least one prompt token per example (otherwise pooling fails)
        3. Not every attended token is prompt (need some completion for LM loss).
           Violations of this one only produce a warning; the check still passes.

        Args:
            labels: Label tensor [B, seq_len]
            attention_mask: Attention mask [B, seq_len]
            prompt_mask: Extracted prompt mask [B, seq_len]
            global_step: Current training step

        Returns:
            Dict with check results. On failure: ``passed=False``, ``reason``
            and ``batch_size`` (plus ``num_zero`` for the zero-prompt case).
            On success: ``passed=True`` with prompt-token statistics and
            ``num_all_prompt``, the number of examples without any completion
            token.
        """
        batch_size = labels.size(0)

        # Expected mask: prompt tokens are (labels == -100) AND (attention_mask == 1)
        attended = attention_mask == 1
        expected_mask = ((labels == -100) & attended).float()

        # Check 1: Correctness
        matches = (prompt_mask == expected_mask).all()
        if not matches:
            self._warn_once(
                f"prompt_mask_wrong_step_{global_step}",
                f"Step {global_step}: Prompt mask does NOT match (labels==-100) & (attention_mask==1)"
            )
            return {
                "passed": False,
                "reason": "mask_mismatch",
                "batch_size": batch_size,
            }

        # Check 2: At least one prompt token per example
        num_prompt_tokens = prompt_mask.sum(dim=1)  # [B]
        num_zero_prompt = (num_prompt_tokens == 0).sum().item()

        if num_zero_prompt > 0:
            self._warn_once(
                f"zero_prompt_tokens_step_{global_step}",
                f"Step {global_step}: {num_zero_prompt}/{batch_size} examples have ZERO prompt tokens! "
                f"This is a DATA FORMATTING bug - labels should have -100 for prompt tokens."
            )
            return {
                "passed": False,
                "reason": "zero_prompt_tokens",
                "num_zero": num_zero_prompt,
                "batch_size": batch_size,
            }

        # Check 3: Not all tokens are prompt (should have some completion).
        # Compare against the number of attended tokens rather than seq_len:
        # padded positions can never be prompt tokens, so a padded example with
        # no completion would otherwise go undetected.
        num_attended_tokens = attended.sum(dim=1)  # [B]
        num_all_prompt = (num_prompt_tokens == num_attended_tokens).sum().item()

        if num_all_prompt > 0:
            self._warn_once(
                f"all_prompt_tokens_step_{global_step}",
                f"Step {global_step}: {num_all_prompt}/{batch_size} examples have ALL tokens as prompt. "
                f"No completion tokens found."
            )

        results = {
            "passed": True,
            "avg_prompt_tokens": num_prompt_tokens.float().mean().item(),
            "min_prompt_tokens": num_prompt_tokens.min().item(),
            "max_prompt_tokens": num_prompt_tokens.max().item(),
            "num_all_prompt": num_all_prompt,
            "batch_size": batch_size,
        }

        # Print a short summary for the first three steps only.
        if global_step < 3:
            print(f"  ✅ [Prompt Mask Check @ step {global_step}]: "
                  f"avg_prompt_tokens={results['avg_prompt_tokens']:.1f}, "
                  f"min={results['min_prompt_tokens']:.0f}, "
                  f"max={results['max_prompt_tokens']:.0f}")

        return results
    
    def check_loss_invariants(
        self,
        total_loss: torch.Tensor,
        router_loss: Optional[torch.Tensor],
        lm_loss: Optional[torch.Tensor],
        router_loss_weight: float,
        lm_loss_weight: float,
        loss_mode: str,
        global_step: int,
        tolerance: float = 1e-4,
    ) -> Dict[str, Any]:
        """
        Check loss computation invariants.

        Invariants:
        1. All losses are finite (no NaN or Inf)
        2. All (finite) losses are non-negative
        3. If supervised+router: total_loss == router_w*router_loss + lm_w*lm_loss (within tolerance)
        4. If router-only: total_loss == router_w*router_loss (within tolerance),
           i.e. the LM loss does not contribute

        Args:
            total_loss: Total combined loss (single-element tensor)
            router_loss: Router loss component (or None)
            lm_loss: LM loss component (or None)
            router_loss_weight: Weight for router loss
            lm_loss_weight: Weight for LM loss
            loss_mode: Loss mode string. "supervised+router" and
                "supervised+router+graph" select invariant 3; "router" and
                "router+graph" select invariant 4; other values skip both.
            global_step: Current training step
            tolerance: Tolerance for equality check

        Returns:
            Dict with check results. A non-finite total loss returns
            immediately with only ``passed`` and ``reason``. Otherwise the dict
            carries ``passed``, the scalar loss values, and ``reason`` for the
            last failing check (``diff`` too for a combination mismatch).
        """
        import math

        results: Dict[str, Any] = {"passed": True}

        # Read every scalar from its tensor once; each .item() call may force a
        # device synchronisation, so the values are reused below.
        total_value = total_loss.item()

        # Check 1: Finiteness
        if not math.isfinite(total_value):
            self._warn_once(
                f"loss_not_finite_step_{global_step}",
                f"Step {global_step}: Total loss is NOT finite: {total_value}"
            )
            results["passed"] = False
            results["reason"] = "loss_not_finite"
            return results

        router_value = router_loss.item() if router_loss is not None else None
        lm_value = lm_loss.item() if lm_loss is not None else None

        components = (
            ("router", "Router", router_value),
            ("lm", "LM", lm_value),
        )
        for name, label, value in components:
            if value is not None and not math.isfinite(value):
                self._warn_once(
                    f"{name}_loss_not_finite_step_{global_step}",
                    f"Step {global_step}: {label} loss is NOT finite: {value}"
                )
                results["passed"] = False
                results["reason"] = f"{name}_loss_not_finite"

        # Check 2: Non-negativity. Non-finite values were reported above and
        # are skipped here so each problem is flagged exactly once.
        for name, label, value in (("total", "Total", total_value),) + components:
            if value is not None and math.isfinite(value) and value < 0:
                self._warn_once(
                    f"{name}_loss_negative_step_{global_step}",
                    f"Step {global_step}: {label} loss is negative: {value:.6f}"
                )
                results["passed"] = False
                results["reason"] = f"{name}_loss_negative"

        # Check 3: Loss combination correctness. The expected value is built
        # with tensor arithmetic so rounding matches how the loss was combined.
        if loss_mode in ["supervised+router", "supervised+router+graph"]:
            if router_loss is not None and lm_loss is not None:
                expected_value = (
                    router_loss_weight * router_loss + lm_loss_weight * lm_loss
                ).item()
                diff = abs(total_value - expected_value)

                # A NaN diff (non-finite component) compares False and is
                # therefore not reported a second time here.
                if diff > tolerance:
                    self._warn_once(
                        f"loss_combination_wrong_step_{global_step}",
                        f"Step {global_step}: Loss combination mismatch! "
                        f"total={total_value:.6f}, "
                        f"expected={expected_value:.6f} (diff={diff:.6f})"
                    )
                    results["passed"] = False
                    results["reason"] = "loss_combination_mismatch"
                    results["diff"] = diff

        elif loss_mode in ["router", "router+graph"]:
            # Pure router mode: LM loss should not contribute
            if router_loss is not None:
                expected_value = (router_loss_weight * router_loss).item()
                diff = abs(total_value - expected_value)

                if diff > tolerance:
                    self._warn_once(
                        f"router_only_loss_wrong_step_{global_step}",
                        f"Step {global_step}: Router-only mode but total_loss != router_loss*weight. "
                        f"total={total_value:.6f}, expected={expected_value:.6f}"
                    )
                    results["passed"] = False
                    results["reason"] = "router_only_mismatch"

        results["total_loss"] = total_value
        if router_value is not None:
            results["router_loss"] = router_value
        if lm_value is not None:
            results["lm_loss"] = lm_value

        return results
    
    def check_hard_mining_invariants(
        self,
        hard_cache: Dict,
        K_hard: int,
        global_step: int,
        min_cache_entries: int = 5,
    ) -> Dict[str, Any]:
        """
        Check hard negative mining invariants.

        Invariants:
        1. Mining runs under torch.no_grad() (checked externally, not here)
        2. Cache is non-empty after mining (the only condition that fails the check)
        3. Cache holds at least ``min_cache_entries`` entries (warning only)
        4. Cache entries hold >= K_hard candidates; a warning is printed when
           more than 50% of the entries fall short (warning only)

        Args:
            hard_cache: Hard negative cache dict; each value must support ``len()``
            K_hard: Expected number of hard negatives per cache entry
            global_step: Current training step
            min_cache_entries: Minimum expected cache entries after mining

        Returns:
            Dict with check results: ``passed``, ``cache_size``,
            ``min_expected`` and, for a non-empty cache, ``below_min_entries``,
            ``entries_with_too_few`` and ``pct_with_too_few``. An empty cache
            yields ``passed=False`` with ``reason="cache_empty"``.
        """
        cache_size = len(hard_cache)

        results = {
            "passed": True,
            "cache_size": cache_size,
            "min_expected": min_cache_entries,
        }

        # Check 1: Cache is non-empty
        if cache_size == 0:
            self._warn_once(
                f"hard_cache_empty_step_{global_step}",
                f"Step {global_step}: Hard negative cache is EMPTY after mining"
            )
            results["passed"] = False
            results["reason"] = "cache_empty"
            return results

        # Check 2: Cache reached the expected minimum size. This only warns,
        # so results that passed before this check existed still pass.
        below_min_entries = cache_size < min_cache_entries
        if below_min_entries:
            self._warn_once(
                f"hard_cache_small_step_{global_step}",
                f"Step {global_step}: Hard negative cache has only {cache_size} entries "
                f"(expected >= {min_cache_entries})"
            )
        results["below_min_entries"] = below_min_entries

        # Check 3: Each cache entry has enough candidates
        entries_with_too_few = sum(
            1 for candidates in hard_cache.values() if len(candidates) < K_hard
        )
        # cache_size > 0 is guaranteed by the early return above.
        pct_with_too_few = (entries_with_too_few / cache_size) * 100

        if pct_with_too_few > 50:  # More than 50% have too few
            self._warn_once(
                f"hard_cache_too_few_step_{global_step}",
                f"Step {global_step}: {entries_with_too_few}/{cache_size} ({pct_with_too_few:.1f}%) cache entries "
                f"have < {K_hard} candidates"
            )

        results["entries_with_too_few"] = entries_with_too_few
        results["pct_with_too_few"] = pct_with_too_few

        return results
    
    def log_sampler_stats(
        self,
        dataset_size: int,
        batch_size: int,
        drop_last: bool,
        sampler_length: int,
        global_step: int,
    ) -> Dict[str, Any]:
        """
        Collect sampler statistics and check for early-stop bugs.

        The expected number of batches is ``dataset_size // batch_size`` when
        ``drop_last`` is set and ``ceil(dataset_size / batch_size)`` otherwise.
        The check is deliberately lenient: it fails only when the dataset is
        larger than one batch and the sampler yields fewer than half of the
        expected batches.

        Args:
            dataset_size: Total dataset size
            batch_size: Batch size
            drop_last: Whether sampler drops last incomplete batch
            sampler_length: Actual sampler length (number of batches)
            global_step: Current training step

        Returns:
            Dict with the inputs, ``expected_length`` and ``passed``; ``reason``
            is "early_stop_bug" or "invalid_batch_size" on failure. For a
            non-positive ``batch_size`` the expected length cannot be computed
            and is reported as 0.
        """
        results = {
            "dataset_size": dataset_size,
            "batch_size": batch_size,
            "drop_last": drop_last,
            "sampler_length": sampler_length,
            "expected_length": 0,
            "passed": True,
        }

        # A debug helper must not crash training with ZeroDivisionError;
        # report the bad configuration as a failed check instead.
        if batch_size <= 0:
            self._warn_once(
                "sampler_invalid_batch_size",
                f"Step {global_step}: Invalid batch size {batch_size}; cannot compute expected sampler length."
            )
            results["passed"] = False
            results["reason"] = "invalid_batch_size"
            return results

        if drop_last:
            expected_length = dataset_size // batch_size
        else:
            # Integer ceiling division.
            expected_length = (dataset_size + batch_size - 1) // batch_size
        results["expected_length"] = expected_length

        # Check for early-stop bug
        if dataset_size > batch_size and sampler_length < expected_length / 2:
            self._warn_once(
                "sampler_early_stop",
                f"Step {global_step}: CRITICAL: Sampler yields {sampler_length} batches but expected ~{expected_length}. "
                f"Dataset size: {dataset_size}, batch size: {batch_size}. "
                f"This indicates an early-stop bug!"
            )
            results["passed"] = False
            results["reason"] = "early_stop_bug"

        return results

