import torch
import torch.nn.functional as F
from typing import Dict, Any, Optional

from hr2r.model.registry import register_loss_func, capture_init_args, get_loss_func_class
from hr2r.train import weighted_cross_entropy, fixed_cross_entropy
from transformers.utils import logging

logger = logging.get_logger(__name__)


class LossFunc:
    """Base class for the pluggable loss functions in this module.

    Subclasses may allocate per-step state in ``prepare_loss``, contribute a
    per-iteration term through ``intra_iter_loss_func`` (only when their
    ``_is_intra_iter_loss`` flag is True), and must produce the final loss in
    ``final_loss_func``.
    """

    # Subclasses set this to True when they implement intra_iter_loss_func.
    _is_intra_iter_loss: bool = False

    def __init__(self, **kwargs):
        # Remember the raw configuration keyword arguments.
        self.config = dict(kwargs)

    def prepare_loss(self, batch_size, query_len, device, dtype):
        """Hook for allocating per-step state; the base implementation does nothing."""

    def intra_iter_loss_func(self, *args, **kwargs):
        """Per-iteration loss hook; unsupported unless a subclass overrides it."""
        message = "This loss function does not support intra-iteration loss calculation."
        raise NotImplementedError(message)

    def final_loss_func(self, *args, **kwargs) -> torch.Tensor:
        """Return the final loss tensor; every concrete subclass must override this."""
        raise NotImplementedError


@register_loss_func
@capture_init_args
class NextTokenPredLoss(LossFunc):
    """
    Next-token-prediction cross-entropy with optional per-token weighting.

    During training, tokens whose iteration count is greater than 1 ("important"
    tokens) can be weighted differently from the rest:

    - if both ``weight_important`` and ``weight_normal`` are given, they are used as-is;
    - otherwise important tokens get ``important_token_relative_weight`` and the
      remaining tokens get 1.0.

    When no weighting is configured (relative weight 1.0 and no custom weights), or
    outside training, the unweighted ``fixed_cross_entropy`` is used.
    """
    _is_intra_iter_loss: bool = False

    def __init__(
        self,
        important_token_relative_weight: float = 1.0,
        weight_important: Optional[float] = None,
        weight_normal: Optional[float] = None,
        **kwargs,
    ):
        # Forward extra kwargs so they land in self.config, like the other losses here.
        super().__init__(**kwargs)
        self.important_token_relative_weight = important_token_relative_weight
        self.weight_important = weight_important
        self.weight_normal = weight_normal

    def final_loss_func(
        self,
        logits: torch.Tensor,
        labels_shifted: torch.Tensor,
        iter_count: torch.Tensor,
        training: bool,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the (optionally weighted) next-token cross-entropy.

        Args:
            logits: Model logits with the vocabulary on the last dimension; flattened
                to (-1, vocab_size).
            labels_shifted: Target ids aligned with ``logits``; -100 is ignored.
            iter_count: Per-token iteration counts aligned with ``labels_shifted``. A
                tensor that is -1 everywhere means counts are unavailable, so every
                token keeps the normal weight.
            training: Token weighting is only applied when True.
            **kwargs: ``num_items_in_batch`` is forwarded to the cross-entropy helper.

        Returns:
            The loss tensor produced by ``fixed_cross_entropy`` or
            ``weighted_cross_entropy``.
        """
        num_items_in_batch = kwargs.get("num_items_in_batch", None)
        ignore_index = -100

        # Upcast to float32 to avoid precision issues (mirrors the transformers
        # implementation). reshape tolerates non-contiguous inputs, unlike view.
        vocab_size = logits.shape[-1]
        flat_logits = logits.float().reshape(-1, vocab_size)
        flat_labels = labels_shifted.reshape(-1).to(flat_logits.device)

        has_custom_weights = (
            self.weight_important is not None and self.weight_normal is not None
        )
        # Custom weights must enable weighting on their own; previously they were
        # ignored whenever important_token_relative_weight stayed at its 1.0 default.
        use_weighting = training and (
            has_custom_weights or self.important_token_relative_weight != 1.0
        )

        if not use_weighting:
            return fixed_cross_entropy(
                flat_logits,
                flat_labels,
                num_items_in_batch=num_items_in_batch,
                ignore_index=ignore_index,
            )

        if has_custom_weights:
            weight_important, weight_normal = self.weight_important, self.weight_normal
        else:
            weight_important, weight_normal = self.important_token_relative_weight, 1.0

        token_weights = torch.full_like(
            flat_labels, weight_normal, dtype=flat_logits.dtype
        )
        # All -1 means no per-token iteration counts: keep uniform normal weights.
        if not (iter_count == -1).all():
            flat_iter = iter_count.reshape(-1).to(token_weights.device)
            token_weights[flat_iter > 1] = weight_important

        return weighted_cross_entropy(
            flat_logits,
            flat_labels,
            token_weights,
            num_items_in_batch=num_items_in_batch,
            ignore_index=ignore_index,
        )


@register_loss_func
@capture_init_args
class ConsistencyLoss(LossFunc):
    """
    Next-token cross-entropy accumulated over recurrent iterations.

    Each call to ``intra_iter_loss_func`` computes per-token cross-entropy for the
    currently active tokens and adds it into ``consistency_loss_per_token``.
    ``final_loss_func`` divides each token's accumulated loss by its iteration
    count and reduces over valid tokens.
    """
    _is_intra_iter_loss: bool = True

    def __init__(self, **kwargs):
        from hr2r.model.recurrent_transformer import HR2RForCausalLM  # local import avoids a circular import
        self.assign_active = HR2RForCausalLM.assign_active
        super().__init__(**kwargs)
        # Accumulator allocated by prepare_loss(); None means "not prepared" or
        # "already consumed" so the guards below raise a clear RuntimeError.
        self.consistency_loss_per_token: Optional[torch.Tensor] = None

    def prepare_loss(self, batch_size, query_len, device, dtype):
        """Allocate a zeroed (batch_size, query_len) float32 loss accumulator."""
        self.consistency_loss_per_token = torch.zeros(
            batch_size, query_len, device=device, dtype=torch.float32
        )  # always float32 for loss, regardless of the model dtype

    def intra_iter_loss_func(
        self,
        active_logits: torch.Tensor,
        current_iter_mask: torch.BoolTensor,
        active_labels_shifted: torch.Tensor,
        **kwargs,
    ):
        """
        Add this iteration's per-token cross-entropy to the accumulator.

        Args:
            active_logits: Logits of the currently active tokens, vocabulary on the
                last dim; flattened to (-1, vocab_size).
            current_iter_mask: (batch_size, query_len) mask of tokens active in this
                iteration.
            active_labels_shifted: Targets for the active tokens; -100 is ignored.
                If None, nothing is accumulated.

        Returns:
            A (batch_size, query_len) float32 tensor of this iteration's per-token
            losses, or a float32 scalar zero when nothing is active.

        Raises:
            RuntimeError: If ``prepare_loss`` has not been called.
        """
        if self.consistency_loss_per_token is None:
            raise RuntimeError(
                "Consistency loss tensor not initialized. Call `prepare_loss` first."
            )

        # Return early before upcasting / allocating anything.
        if not current_iter_mask.any() or active_labels_shifted is None:
            return torch.tensor(0.0, device=active_logits.device, dtype=torch.float32)

        batch_size, query_len = current_iter_mask.shape
        # Upcast to float32 to avoid precision issues (mirrors transformers).
        active_logits = active_logits.float()

        flat_losses = F.cross_entropy(
            active_logits.reshape(-1, active_logits.size(-1)),
            active_labels_shifted.reshape(-1),
            reduction="none",
            ignore_index=-100,
        )

        token_losses = torch.zeros(
            batch_size, query_len, device=active_logits.device, dtype=active_logits.dtype
        )
        # assign_active (defined on HR2RForCausalLM) is presumed to scatter the
        # active-token losses into the full-size positions selected by
        # current_iter_mask — its exact semantics are not visible here.
        self.assign_active(current_iter_mask, flat_losses.view(batch_size, -1), token_losses)
        self._update_consistency_loss(token_losses)

        return token_losses

    def _update_consistency_loss(self, token_losses):
        """Add ``token_losses`` into the running per-token accumulator (out of place)."""
        self.consistency_loss_per_token = token_losses + self.consistency_loss_per_token

    def final_loss_func(
        self,
        labels_shifted: torch.Tensor,
        iter_count: torch.Tensor,
        training: bool,
        **kwargs,
    ) -> torch.Tensor:
        """
        Reduce the accumulated loss to a scalar and consume the accumulator.

        Each valid token (label != -100 and iter_count > 0) contributes its
        accumulated loss divided by its iteration count. The result is the sum
        over valid tokens divided by ``num_items_in_batch`` when that kwarg is
        given, otherwise the mean over valid tokens.

        Raises:
            RuntimeError: If the accumulator was never prepared or was already consumed.
        """
        if self.consistency_loss_per_token is None:
            raise RuntimeError(
                "Consistency loss tensor not initialized or already consumed."
            )

        num_items_in_batch = kwargs.get("num_items_in_batch", None)
        valid_mask = (labels_shifted != -100) & (iter_count > 0)

        consistency_loss = self.consistency_loss_per_token
        self.consistency_loss_per_token = None  # Consume the loss

        if not valid_mask.any():
            return torch.tensor(
                0.0, device=labels_shifted.device, dtype=consistency_loss.dtype
            )

        per_token_avg = (
            consistency_loss[valid_mask] / iter_count[valid_mask].float()
        ).to(dtype=consistency_loss.dtype)

        if num_items_in_batch is not None:
            return per_token_avg.sum() / num_items_in_batch
        return per_token_avg.mean()

@register_loss_func
@capture_init_args
class WeightedConsistencyLoss(ConsistencyLoss):
    """
    Consistency loss weighted by the iter decider's stopping probabilities.

    At each iteration a token's loss is multiplied by its stop probability
    (``1 - sigmoid(continue_logit)``) and by the running product of
    ``(1 - stop_prob)`` from earlier iterations. The accumulated value therefore
    weights each depth by the (approximate) probability of stopping there.
    Unlike ``ConsistencyLoss``, ``final_loss_func`` does not divide by the
    iteration count.
    """

    def __init__(self, detach_grad: str = "none", max_iter: Optional[int] = None, **kwargs):
        """
        detach_grad: control gradient flow for weighted consistency loss.
          - "none": allow gradients through both NTP token losses and stop_prob
          - "ntp": detach token losses (train decider only via stop_prob path)
          - "prob": detach stop_prob (train NTP path only, no grad to decider)
        max_iter: optional maximum iteration depth; at depth >= max_iter the stop
          probability of valid tokens is forced to 1.0.

        Backward-compat: accepts legacy decider_only_grad in kwargs.
        decider_only_grad=True maps to detach_grad="ntp"; False -> "none".
        """
        # Backward compatibility shim
        if "decider_only_grad" in kwargs:
            legacy = bool(kwargs.pop("decider_only_grad"))
            mapped = "ntp" if legacy else "none"
            logger.warning(
                f"WeightedConsistencyLoss: 'decider_only_grad' is deprecated. Use 'detach_grad' instead. Mapped to '{mapped}'."
            )
            detach_grad = mapped

        super().__init__(**kwargs)

        if not isinstance(detach_grad, str):
            raise TypeError("detach_grad must be one of {'none','ntp','prob'}")
        detach_mode = detach_grad.lower().strip()
        if detach_mode not in ("none", "ntp", "prob"):
            raise ValueError("detach_grad must be one of {'none','ntp','prob'}")
        self.detach_grad = detach_mode
        # Optional explicit max_iter (wired from model config)
        self.max_iter: Optional[int] = int(max_iter) if max_iter is not None else None

        # Per-step buffers; allocated by prepare_loss() / intra_iter_loss_func().
        self.accumulative_continue_prob: Optional[torch.Tensor] = None
        self.stop_prob: Optional[torch.Tensor] = None

    def prepare_loss(self, batch_size, query_len, device, dtype):
        """Allocate the loss accumulator and reset the running continue probability to 1."""
        super().prepare_loss(batch_size, query_len, device, dtype)
        # float32 like the parent's loss accumulator: this is a running product over
        # iterations and would lose precision quickly in half precision.
        self.accumulative_continue_prob = torch.ones(
            batch_size, query_len, device=device, dtype=torch.float32
        )
        self.stop_prob = None

    def _update_consistency_loss(self, token_losses: torch.Tensor):
        """Accumulate ``token_losses * P(reached this depth) * P(stop here)``."""
        # Detach token loss path if requested (train decider only)
        if self.detach_grad == "ntp":
            token_losses = token_losses.detach()
        self.consistency_loss_per_token = (
            self.consistency_loss_per_token
            + token_losses * self.accumulative_continue_prob * self.stop_prob
        )

    def intra_iter_loss_func(
        self,
        active_logits: torch.Tensor,
        current_iter_mask: torch.BoolTensor,
        active_labels_shifted: torch.Tensor,
        active_valid_continue_logits: torch.Tensor,  # shape (batch_size, valid_length)
        active_valid_mask: torch.LongTensor,  # shape (batch_size, active_length)
        iter_depth: Optional[int] = None,
        **kwargs,
    ):
        """
        Compute this iteration's stop probabilities, accumulate the weighted loss,
        then update the running continue probability.

        Tokens outside the valid active set get stop probability 1.0, so their
        running continue probability drops to 0 afterwards.

        Returns:
            The per-token losses returned by ``ConsistencyLoss.intra_iter_loss_func``.

        Raises:
            RuntimeError: If ``prepare_loss`` has not been called.
        """
        if self.accumulative_continue_prob is None:
            raise RuntimeError(
                "Weighted consistency buffers not initialized. Call `prepare_loss` first."
            )

        stop_prob = torch.ones_like(self.accumulative_continue_prob)

        # Expand the active-level validity mask to full (batch, query_len) positions.
        full_valid_mask = torch.zeros_like(current_iter_mask, dtype=torch.bool)
        full_valid_mask[current_iter_mask] = (active_valid_mask == 1).reshape(-1)

        stop_prob_local = 1 - torch.sigmoid(active_valid_continue_logits)
        # Optionally detach prob path to block gradients to decider
        if self.detach_grad == "prob":
            stop_prob_local = stop_prob_local.detach()
        stop_prob[full_valid_mask] = stop_prob_local.reshape(-1).to(stop_prob.dtype)

        # On the final allowed iteration depth, force stop probability to 1.0
        if (
            iter_depth is not None
            and self.max_iter is not None
            and int(iter_depth) >= int(self.max_iter)
        ):
            stop_prob[full_valid_mask] = 1.0

        # _update_consistency_loss (called from the parent) reads self.stop_prob.
        self.stop_prob = stop_prob
        token_losses = super().intra_iter_loss_func(
            active_logits,
            current_iter_mask,
            active_labels_shifted,
        )

        self.accumulative_continue_prob = (1 - self.stop_prob) * self.accumulative_continue_prob

        return token_losses

    def final_loss_func(
        self,
        labels_shifted: torch.Tensor,
        iter_count: torch.Tensor,
        training: bool,
        **kwargs,
    ) -> torch.Tensor:
        """
        Reduce the accumulated weighted loss over valid tokens and consume it.

        Valid tokens have label != -100 and iter_count > 0. Returns the sum over
        valid tokens divided by ``num_items_in_batch`` when given, else the mean.

        Raises:
            RuntimeError: If the accumulator was never prepared or was already consumed.
        """
        if getattr(self, "consistency_loss_per_token", None) is None:
            raise RuntimeError(
                "Consistency loss tensor not initialized or already consumed."
            )

        num_items_in_batch = kwargs.get("num_items_in_batch", None)
        valid_mask = (labels_shifted != -100) & (iter_count > 0)
        consistency_loss = self.consistency_loss_per_token
        self.consistency_loss_per_token = None  # Consume the loss

        if not valid_mask.any():
            return torch.tensor(
                0.0, device=labels_shifted.device, dtype=consistency_loss.dtype
            )

        valid_losses = consistency_loss[valid_mask]
        if num_items_in_batch is not None:
            return valid_losses.sum() / num_items_in_batch
        return valid_losses.mean()

@register_loss_func
@capture_init_args
class IterDeciderLoss(LossFunc):
    """
    Loss for the iter decider, which predicts whether each token should keep iterating.

    At every iteration depth ``d`` the decider's continue-logit for each valid,
    non-padding active token is trained with BCE against the target
    ``iter_count_label > d``. Per-depth losses are accumulated into
    ``iter_decider_loss_per_token`` and reduced in ``final_loss_func``.
    Threshold-based accuracy counters are accumulated alongside and optionally
    reported through a ``logger_callback``.
    """
    _is_intra_iter_loss: bool = True

    def __init__(self, pos_weight: Optional[float] = None, skip_last_iter: bool = True, max_iter: Optional[int] = None, **kwargs):
        """
        Initialize IterDeciderLoss.

        Args:
            pos_weight: Positive-class weight for the BCE loss (class imbalance).
            skip_last_iter: If True, skip the loss at depth >= max_iter, where the
                decision is always "stop".
            max_iter: Maximum iteration depth; required when ``skip_last_iter``.
            **kwargs: Stored on ``self.config`` by ``LossFunc``.

        Raises:
            ValueError: If ``skip_last_iter`` is True and ``max_iter`` is None.
        """
        from hr2r.model.recurrent_transformer import HR2RForCausalLM  # local import avoids a circular import
        self.assign_active = HR2RForCausalLM.assign_active
        super().__init__(**kwargs)
        self.pos_weight = pos_weight
        self.skip_last_iter = bool(skip_last_iter)
        # Optional explicit max_iter (preferred over reading from model at call time)
        self.max_iter: Optional[int] = int(max_iter) if max_iter is not None else None

        if self.skip_last_iter and self.max_iter is None:
            raise ValueError("max_iter must be provided if skip_last_iter is True")

        # Per-step state allocated by prepare_loss(); None = not prepared / consumed.
        self.iter_decider_loss_per_token: Optional[torch.Tensor] = None
        self._metric_correct_count: Optional[torch.Tensor] = None
        self._metric_total_count: Optional[torch.Tensor] = None

        # BCE criterion with the default reduction ("mean" over its inputs)
        if pos_weight is not None:
            self.criterion = torch.nn.BCEWithLogitsLoss(
                pos_weight=torch.tensor([pos_weight])
            )
        else:
            self.criterion = torch.nn.BCEWithLogitsLoss()

    def prepare_loss(self, batch_size, query_len, device, dtype):
        """Allocate the float32 per-token loss accumulator and reset the accuracy counters."""
        self.iter_decider_loss_per_token = torch.zeros(
            batch_size, query_len, device=device, dtype=torch.float32
        )  # always use float32 for loss
        # Metric accumulators (float32 scalars on device)
        self._metric_correct_count = torch.zeros(1, device=device, dtype=torch.float32)
        self._metric_total_count = torch.zeros(1, device=device, dtype=torch.float32)

    def intra_iter_loss_func(
        self,
        active_logits: torch.Tensor,
        current_iter_mask: torch.BoolTensor,
        active_labels_shifted: torch.Tensor,
        active_valid_continue_logits: Optional[torch.Tensor],
        active_valid_mask: torch.LongTensor,
        iter_depth: int,
        active_iter_count_labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """
        Calculate the iter decider loss at one iteration depth.

        Args:
            active_logits: Model logits (not used).
            current_iter_mask: (batch_size, query_len) mask of tokens active in
                this iteration.
            active_labels_shifted: Shifted labels (not used).
            active_valid_continue_logits: Continue logits from the iter decider for
                the valid active tokens. Indexed here with a 1-D mask over valid
                tokens, so presumably flat — TODO confirm.
            active_valid_mask: Active-level mask; entries equal to 1 are valid.
            iter_depth: Current iteration depth.
            active_iter_count_labels: Target iteration counts for active tokens;
                -100 marks padding.
            **kwargs: ``iter_decider_threshold`` (float or tensor, default 0.5) is
                used for the accuracy metric.

        Returns:
            (batch_size, query_len) float32 per-token losses, or a float32 scalar
            zero when nothing contributes.

        Raises:
            RuntimeError: If ``prepare_loss`` has not been called.
        """
        device = current_iter_mask.device

        if active_iter_count_labels is None or active_valid_continue_logits is None:
            return torch.tensor(0.0, device=device, dtype=torch.float32)

        if not current_iter_mask.any() or active_valid_mask.sum() == 0:
            return torch.tensor(0.0, device=device, dtype=torch.float32)

        # Optionally skip loss for the final iteration depth (always-stop step)
        if self.skip_last_iter and int(iter_depth) >= int(self.max_iter):
            return torch.tensor(0.0, device=device, dtype=torch.float32)

        # Checked before any side effects (metrics) so a missing prepare fails cleanly.
        if self.iter_decider_loss_per_token is None:
            raise RuntimeError(
                "Iter decider loss tensor not initialized. Call `prepare_loss` first."
            )

        continue_logits = active_valid_continue_logits.float()
        device = continue_logits.device
        batch_size, query_len = current_iter_mask.shape

        # Targets: continue iff the labelled iteration count exceeds this depth.
        valid_active_mask = (active_valid_mask == 1)
        valid_iter_count_labels = active_iter_count_labels[valid_active_mask]
        non_padding_mask = (valid_iter_count_labels != -100)
        if not non_padding_mask.any():
            return torch.tensor(0.0, device=device, dtype=continue_logits.dtype)

        continue_targets = (valid_iter_count_labels[non_padding_mask] > iter_depth).float()
        contributing_logits = continue_logits[non_padding_mask]

        self._accumulate_accuracy(
            contributing_logits,
            continue_targets,
            kwargs.get('iter_decider_threshold', 0.5),
        )

        # Move pos_weight to correct device if needed
        if getattr(self.criterion, 'pos_weight', None) is not None:
            self.criterion.pos_weight = self.criterion.pos_weight.to(device=device)

        loss = self.criterion(contributing_logits.unsqueeze(-1), continue_targets.unsqueeze(-1))

        # Spread the batch loss evenly over the contributing tokens.
        # NOTE(review): `loss` is already the mean BCE over the N contributing tokens
        # (default reduction), so `loss / N` per token makes this depth sum to the
        # mean, which final_loss_func then divides again by num_items_in_batch (or
        # averages). Scale kept unchanged; confirm it is intended.
        num_contributing_tokens = non_padding_mask.sum()
        per_token_loss = loss / num_contributing_tokens

        active_token_losses = torch.zeros(
            batch_size, active_valid_mask.shape[1], device=device, dtype=loss.dtype
        )
        contributing_positions = torch.zeros_like(active_token_losses, dtype=torch.bool)
        contributing_positions[valid_active_mask] = non_padding_mask
        active_token_losses[contributing_positions] = per_token_loss

        token_losses = torch.zeros(
            batch_size, query_len, device=device, dtype=continue_logits.dtype
        )
        # assign_active (from HR2RForCausalLM) is presumed to scatter active-level
        # values into full positions selected by current_iter_mask — not visible here.
        self.assign_active(current_iter_mask, active_token_losses, token_losses)

        self._update_iter_decider_loss(token_losses)

        return token_losses

    def _accumulate_accuracy(self, continue_logits, continue_targets, threshold):
        """Add this depth's threshold-based correct/total counts to the metric accumulators."""
        if self._metric_correct_count is None or self._metric_total_count is None:
            return
        if isinstance(threshold, torch.Tensor):
            threshold_value = float(threshold.detach().item())
        else:
            threshold_value = float(threshold)

        with torch.no_grad():
            predicted_continue = (torch.sigmoid(continue_logits) > threshold_value).to(torch.float32)
            correct = (predicted_continue == continue_targets).to(torch.float32).sum()
            self._metric_correct_count += correct
            self._metric_total_count += float(predicted_continue.numel())

    def _update_iter_decider_loss(self, token_losses):
        """Add ``token_losses`` into the running per-token accumulator (out of place)."""
        self.iter_decider_loss_per_token = token_losses + self.iter_decider_loss_per_token

    def _log_accuracy(self, logger_callback, num_items_in_batch):
        """Add this step's accuracy to ``logger_callback.iter_decider_accuracy``, if a callback is given."""
        if (
            logger_callback is None
            or self._metric_total_count is None
            or self._metric_correct_count is None
        ):
            return
        with torch.no_grad():
            if not ((self._metric_total_count.item() > 0) or (num_items_in_batch is not None)):
                return
            if not hasattr(logger_callback, 'iter_decider_accuracy'):
                logger_callback.iter_decider_accuracy = 0.0
            # Follows the avg_iter_count pattern: correct_count / num_items_in_batch
            if num_items_in_batch is not None and num_items_in_batch > 0:
                acc_step = self._metric_correct_count / num_items_in_batch
            else:
                # Fallback to total-based accuracy if num_items_in_batch is absent
                total_safe = torch.clamp(self._metric_total_count, min=1.0)
                acc_step = self._metric_correct_count / total_safe
            logger_callback.iter_decider_accuracy += float(acc_step)

    def final_loss_func(
        self,
        logits: torch.Tensor,
        labels_shifted: torch.Tensor,
        iter_count: torch.Tensor,
        iter_count_labels: Optional[torch.Tensor] = None,
        training: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """
        Reduce the accumulated iter decider loss and consume the per-step state.

        Args:
            logits: Model logits (only used for the device of the zero result).
            labels_shifted: Shifted labels (not used).
            iter_count: Actual iteration counts from the model.
            iter_count_labels: Target iteration count labels (optional); -100 is
                excluded from the reduction.
            training: Whether in training mode (not used).
            **kwargs: ``num_items_in_batch`` (normaliser) and ``logger_callback``
                (receives ``iter_decider_accuracy``).

        Returns:
            Sum over valid tokens divided by ``num_items_in_batch`` when given,
            otherwise the mean over valid tokens.

        Raises:
            RuntimeError: If the accumulator was never prepared or was already consumed.
        """
        if self.iter_decider_loss_per_token is None:
            raise RuntimeError(
                "Iter decider loss tensor not initialized or already consumed."
            )

        num_items_in_batch = kwargs.get("num_items_in_batch", None)

        # Use iter_count_labels if available, otherwise fall back to simple validation
        if iter_count_labels is not None:
            valid_mask = (iter_count_labels != -100) & (iter_count > 0)
        else:
            valid_mask = (iter_count > 0)

        iter_decider_loss = self.iter_decider_loss_per_token
        self.iter_decider_loss_per_token = None  # Consume the loss

        if not valid_mask.any():
            # Consume the metric accumulators on this path too.
            self._metric_correct_count = None
            self._metric_total_count = None
            return torch.tensor(
                0.0, device=logits.device, dtype=iter_decider_loss.dtype
            )

        self._log_accuracy(kwargs.get('logger_callback', None), num_items_in_batch)

        # Reset metric accumulators after consumption
        self._metric_correct_count = None
        self._metric_total_count = None

        if num_items_in_batch is not None:
            return iter_decider_loss[valid_mask].sum() / num_items_in_batch
        return iter_decider_loss[valid_mask].mean()


@register_loss_func
@capture_init_args
class CombinedLoss(LossFunc):
    """
    Weighted sum of two registry-constructed loss functions.

    The total is ``primary + secondary_loss_weight * secondary``. The typical
    pairing is a token-level loss (e.g. NextTokenPredLoss) with IterDeciderLoss.
    """
    _is_intra_iter_loss: bool = False

    def __init__(
        self,
        primary_loss: str = "NextTokenPredLoss",
        primary_loss_kwargs: Optional[Dict[str, Any]] = None,
        secondary_loss: str = "IterDeciderLoss",
        secondary_loss_kwargs: Optional[Dict[str, Any]] = None,
        secondary_loss_weight: float = 1.0,
        **kwargs
    ):
        """
        Build both component losses from the loss registry.

        Args:
            primary_loss: Registered name of the primary loss class.
            primary_loss_kwargs: Constructor kwargs for the primary loss.
            secondary_loss: Registered name of the secondary loss class.
            secondary_loss_kwargs: Constructor kwargs for the secondary loss.
            secondary_loss_weight: Multiplier applied to the secondary component.
        """
        super().__init__(**kwargs)

        self.primary_loss = get_loss_func_class(primary_loss)(**(primary_loss_kwargs or {}))
        self.secondary_loss = get_loss_func_class(secondary_loss)(**(secondary_loss_kwargs or {}))
        self.secondary_loss_weight = secondary_loss_weight

        # The combination needs per-iteration calls if either component does.
        self._is_intra_iter_loss = any(
            component._is_intra_iter_loss
            for component in (self.primary_loss, self.secondary_loss)
        )

    def prepare_loss(self, batch_size, query_len, device, dtype):
        """Forward per-step preparation to both components (primary first)."""
        for component in (self.primary_loss, self.secondary_loss):
            component.prepare_loss(batch_size, query_len, device, dtype)

    def intra_iter_loss_func(self, *args, **kwargs):
        """
        Run the per-iteration hook of every component that supports it.

        Returns the primary result plus the weighted secondary result when both
        exist, whichever one exists otherwise, or a scalar zero placed on the device
        of the first positional argument that has one (falling back to CPU).
        """
        primary_out = (
            self.primary_loss.intra_iter_loss_func(*args, **kwargs)
            if self.primary_loss._is_intra_iter_loss
            else None
        )
        secondary_out = (
            self.secondary_loss.intra_iter_loss_func(*args, **kwargs)
            if self.secondary_loss._is_intra_iter_loss
            else None
        )

        if secondary_out is None:
            if primary_out is not None:
                return primary_out
            fallback_device = next(
                (candidate.device for candidate in args if hasattr(candidate, 'device')),
                'cpu',
            )
            return torch.tensor(0.0, device=fallback_device)

        weighted_secondary = self.secondary_loss_weight * secondary_out
        if primary_out is None:
            return weighted_secondary
        return primary_out + weighted_secondary

    def final_loss_func(
        self,
        logits: torch.Tensor,
        labels_shifted: torch.Tensor,
        iter_count: torch.Tensor,
        training: bool,
        **kwargs,
    ) -> torch.Tensor:
        """
        Return ``primary + secondary_loss_weight * secondary``.

        The secondary component is only evaluated when a non-None
        ``iter_count_labels`` is present in ``kwargs``; otherwise it contributes a
        zero of the logits' dtype and device.
        """
        component_args = dict(
            logits=logits,
            labels_shifted=labels_shifted,
            iter_count=iter_count,
            training=training,
        )
        primary_value = self.primary_loss.final_loss_func(**component_args, **kwargs)

        if kwargs.get('iter_count_labels') is not None:
            secondary_value = self.secondary_loss.final_loss_func(**component_args, **kwargs)
        else:
            secondary_value = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)

        return primary_value + self.secondary_loss_weight * secondary_value



@register_loss_func
@capture_init_args
class UncertaintyWeightedCombinedLoss(CombinedLoss):
    """
    Combined loss with uncertainty weighting (Kendall & Gal).

    Uses per-component log-sigma scalars to weight losses:
      L_total = 0.5 * exp(-2 * log_sigma_p) * L_p + log_sigma_p
              + 0.5 * exp(-2 * log_sigma_s) * L_s + log_sigma_s

    Notes:
    - This class inherits the composition (primary/secondary) from CombinedLoss.
    - The secondary term is only included when a non-None ``iter_count_labels`` is
      passed to ``final_loss_func``.
    - The log-sigma scalars here are plain tensors (not parameters of the model).
      If you need them to be trainable, wire them into your optimizer explicitly or
      move the loss into a module that exposes parameters.
    - The total is returned in float32 (not cast back to the logits dtype), like the
      float32 sum produced by CombinedLoss.
    """

    _is_intra_iter_loss: bool = False

    def __init__(
        self,
        primary_loss: str = "NextTokenPredLoss",
        primary_loss_kwargs: Optional[Dict[str, Any]] = None,
        secondary_loss: str = "IterDeciderLoss",
        secondary_loss_kwargs: Optional[Dict[str, Any]] = None,
        init_log_sigma_primary: float = 0.0,
        init_log_sigma_secondary: float = 0.0,
        **kwargs,
    ):
        super().__init__(
            primary_loss=primary_loss,
            primary_loss_kwargs=primary_loss_kwargs,
            secondary_loss=secondary_loss,
            secondary_loss_kwargs=secondary_loss_kwargs,
            secondary_loss_weight=1.0,
            **kwargs,
        )
        # Store as float tensors; users can manage learning externally if desired
        self.log_sigma_primary = torch.tensor(float(init_log_sigma_primary))
        self.log_sigma_secondary = torch.tensor(float(init_log_sigma_secondary))

    @staticmethod
    def _uncertainty_weighted(loss: torch.Tensor, log_sigma: torch.Tensor) -> torch.Tensor:
        """Return ``0.5 * exp(-2 * log_sigma) * loss + log_sigma``."""
        return 0.5 * torch.exp(-2.0 * log_sigma) * loss + log_sigma

    def final_loss_func(
        self,
        logits: torch.Tensor,
        labels_shifted: torch.Tensor,
        iter_count: torch.Tensor,
        training: bool,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the uncertainty-weighted total of the component losses (float32).

        The component losses are computed with the same gating rules as
        CombinedLoss: the secondary loss only runs when a non-None
        ``iter_count_labels`` is present in ``kwargs``.
        """
        component_args = dict(
            logits=logits,
            labels_shifted=labels_shifted,
            iter_count=iter_count,
            training=training,
        )

        primary = self.primary_loss.final_loss_func(
            **component_args, **kwargs
        ).to(dtype=torch.float32)
        ls_p = self.log_sigma_primary.to(device=logits.device, dtype=torch.float32)
        total = self._uncertainty_weighted(primary, ls_p)

        if kwargs.get("iter_count_labels") is not None:
            secondary = self.secondary_loss.final_loss_func(
                **component_args, **kwargs
            ).to(dtype=torch.float32)
            ls_s = self.log_sigma_secondary.to(device=logits.device, dtype=torch.float32)
            total = total + self._uncertainty_weighted(secondary, ls_s)

        # Deliberately not cast to logits.dtype: downcasting (e.g. to bf16) would
        # undo the float32 upcast the component losses perform for precision.
        return total

@register_loss_func
@capture_init_args
class InterleavedLoss(LossFunc):
    """
    Switch between two loss functions on fixed intervals and freeze/unfreeze
    corresponding model components to enable interleaved training of the main
    LLM and the iter decider.

    Phases:
    - primary phase: optimize main LLM (unfreeze LLM, freeze iter_decider)
    - secondary phase: optimize iter decider (freeze LLM, unfreeze iter_decider)
    """
    _is_intra_iter_loss: bool = False

    def __init__(
        self,
        primary_loss: str = "NextTokenPredLoss",
        primary_loss_kwargs: Optional[Dict[str, Any]] = None,
        secondary_loss: str = "IterDeciderLoss",
        secondary_loss_kwargs: Optional[Dict[str, Any]] = None,
        interval_steps: int = 100,
        primary_interval_steps: Optional[int] = None,
        secondary_interval_steps: Optional[int] = None,
        start_with: str = "primary",
        freeze_control: str = "none",
        **kwargs,
    ):
        """Build both sub-losses and configure the phase schedule.

        Args:
            primary_loss: Registered name of the loss used in the primary phase.
            primary_loss_kwargs: Keyword arguments for the primary loss.
            secondary_loss: Registered name of the loss used in the secondary phase.
            secondary_loss_kwargs: Keyword arguments for the secondary loss.
            interval_steps: Default length (in steps) of each phase.
            primary_interval_steps: Optional override of the primary phase length.
            secondary_interval_steps: Optional override of the secondary phase length.
            start_with: Phase opening each cycle, ``"primary"`` or ``"secondary"``
                (case/whitespace-insensitive). Anything else falls back to
                ``"primary"`` and a warning is logged.
            freeze_control: One of ``{"none", "all", "primary", "secondary"}``
                (case/whitespace-insensitive); see ``_apply_freeze_policy``.
            **kwargs: Forwarded to ``LossFunc.__init__``.

        Raises:
            TypeError: If ``freeze_control`` is not a string.
            ValueError: If ``freeze_control`` is not one of the accepted modes.
        """
        super().__init__(**kwargs)

        primary_loss_cls = get_loss_func_class(primary_loss)
        self.primary_loss = primary_loss_cls(**(primary_loss_kwargs or {}))

        secondary_loss_cls = get_loss_func_class(secondary_loss)
        self.secondary_loss = secondary_loss_cls(**(secondary_loss_kwargs or {}))

        # Interval configuration
        self.interval_steps = int(interval_steps)
        self.primary_interval_steps = int(primary_interval_steps) if primary_interval_steps is not None else None
        self.secondary_interval_steps = int(secondary_interval_steps) if secondary_interval_steps is not None else None

        # Normalise start_with like freeze_control, but keep the historical
        # lenient fallback to "primary" (now with a warning instead of silently).
        normalized_start = start_with.lower().strip() if isinstance(start_with, str) else None
        if normalized_start not in ("primary", "secondary"):
            logger.warning(
                "[InterleavedLoss] Unknown start_with=%r; falling back to 'primary'.",
                start_with,
            )
            normalized_start = "primary"
        self.start_with = normalized_start

        # Strict string control: one of {none, all, primary, secondary}
        if not isinstance(freeze_control, str):
            raise TypeError("freeze_control must be a string in {'none','all','primary','secondary'}")
        val = freeze_control.lower().strip()
        if val not in ("none", "all", "primary", "secondary"):
            raise ValueError("freeze_control must be one of {'none','all','primary','secondary'}")
        self._freeze_control_mode = val

        # Track step and last-applied phase to avoid redundant toggles
        self._internal_step_counter: int = 0
        self._last_applied_step: Optional[int] = None
        self._last_applied_phase: Optional[str] = None
        # External training state (set through update_training_state)
        self._external_global_step: Optional[int] = None
        self._external_epoch: Optional[int] = None
        # Saved baseline requires_grad states for selective-freeze modes
        self._saved_freeze_states: Dict[str, Any] = {}

        # Intra-iteration support if either sub-loss needs it
        self._is_intra_iter_loss = (
            self.primary_loss._is_intra_iter_loss or self.secondary_loss._is_intra_iter_loss
        )

    def prepare_loss(self, batch_size, query_len, device, dtype):
        # Increment internal counter (used only when global_step is not provided)
        self._internal_step_counter += 1
        # Prepare BOTH losses to avoid phase mismatch when global_step is injected later
        # This keeps intra-iter buffers ready regardless of the active phase
        self.primary_loss.prepare_loss(batch_size, query_len, device, dtype)
        self.secondary_loss.prepare_loss(batch_size, query_len, device, dtype)

    # -------------------- Phase selection helpers --------------------
    def _get_phase(self, global_step: Optional[int] = None) -> str:
        """Return the active phase name for the given step."""
        step = int(global_step) if global_step is not None else self._internal_step_counter

        # Determine intervals
        p_int = self.primary_interval_steps or self.interval_steps
        s_int = self.secondary_interval_steps or self.interval_steps

        cycle = p_int + s_int
        if cycle <= 0:
            return self.start_with

        # Compute position in cycle, considering starting phase
        pos = step % cycle  # step starts from 0
        if self.start_with == "primary":
            return "primary" if pos < p_int else "secondary"
        else:
            return "secondary" if pos < s_int else "primary"

    def _set_requires_grad_for(self, obj, requires_grad: bool):
        if obj is None:
            return
        try:
            for _, param in obj.named_parameters(recurse=True):
                param.requires_grad = requires_grad
        except Exception:
            # Fallback if object is not a nn.Module but exposes parameters()
            if hasattr(obj, "parameters"):
                for param in obj.parameters():
                    try:
                        param.requires_grad = requires_grad
                    except Exception:
                        pass

    def _apply_freeze_policy(self, model, phase: str, step: int):
        if (self._last_applied_step == step) and (self._last_applied_phase == phase):
            return

        mode = getattr(self, "_freeze_control_mode", "none")
        if (mode == 'none') or (model is None):
            # Log phase switch even if no freeze policy is applied
            try:
                self._log_phase_switch(model, phase=phase, step=step, mode=mode)
            except Exception:
                pass
            self._last_applied_step = step
            self._last_applied_phase = phase
            return

        # Components
        simple_model = getattr(model, "simple_base_model", None)
        cascade_model = getattr(model, "cascade_model", None)
        iter_embedding = getattr(model, "iter_embedding", None)
        input_updater = getattr(model, "input_updater", None)
        output_updater = getattr(model, "output_updater", None)
        iter_decider = getattr(model, "iter_decider", None)

        # Helpers to save/restore original requires_grad without forcing to True
        def _gather_params(component):
            if component is None:
                return []
            params = []
            try:
                for _, p in component.named_parameters(recurse=True):
                    params.append(p)
            except Exception:
                if hasattr(component, "parameters"):
                    for p in component.parameters():
                        params.append(p)
            return params

        def _save_and_set(group_name: str, components, requires_grad: bool):
            if not isinstance(components, (list, tuple)):
                components = [components]
            if group_name not in self._saved_freeze_states:
                saved = []
                for comp in components:
                    for p in _gather_params(comp):
                        saved.append((p, bool(p.requires_grad)))
                self._saved_freeze_states[group_name] = saved
            # Apply requested requires_grad to targeted params
            for comp in components:
                self._set_requires_grad_for(comp, requires_grad)

        def _restore(group_name: str):
            saved = self._saved_freeze_states.pop(group_name, None)
            if not saved:
                return
            for p, original in saved:
                try:
                    p.requires_grad = bool(original)
                except Exception:
                    pass

        if mode == 'all':
            if phase == "primary":
                # Train the LLM side
                self._set_requires_grad_for(simple_model, True)
                self._set_requires_grad_for(cascade_model, True)
                self._set_requires_grad_for(iter_embedding, True)
                self._set_requires_grad_for(input_updater, True)
                self._set_requires_grad_for(output_updater, True)
                self._set_requires_grad_for(iter_decider, False)
            else:
                # Train the iter decider side
                self._set_requires_grad_for(simple_model, False)
                self._set_requires_grad_for(cascade_model, False)
                self._set_requires_grad_for(iter_embedding, False)
                self._set_requires_grad_for(input_updater, False)
                self._set_requires_grad_for(output_updater, False)
                self._set_requires_grad_for(iter_decider, True)
        elif mode == 'primary':
            if phase == 'primary':
                # Apply primary-phase freezing only to iter_decider; do not force others True
                _save_and_set("iter_decider", iter_decider, False)
            else:
                # Leaving primary phase: restore iter_decider original state
                _restore("iter_decider")
        elif mode == 'secondary':
            if phase == 'secondary':
                # Apply secondary-phase freezing only to LLM components; do not force decider True
                _save_and_set("llm_group", [simple_model, cascade_model, iter_embedding, input_updater, output_updater], False)
            else:
                # Leaving secondary phase: restore LLM components original states
                _restore("llm_group")

        # After applying policy, log current training state and trainable sizes
        try:
            self._log_phase_switch(model, phase=phase, step=step, mode=mode)
        except Exception:
            pass

        self._last_applied_step = step
        self._last_applied_phase = phase

    def _log_phase_switch(self, model, phase: str, step: int, mode: str):
        """Log the phase, freeze mode and per-component trainable parameter sizes.

        Every known component attribute of ``model`` is reported. Missing
        components (or a ``None`` model) count as empty. Size is the sum of
        ``numel * element_size`` over parameters with ``requires_grad`` set,
        shown in MB (1024 ** 2 bytes). The report goes to ``logger.info``.
        """
        bytes_per_mb = 1024.0 ** 2
        tracked_names = (
            "simple_base_model",
            "cascade_model",
            "iter_embedding",
            "input_updater",
            "output_updater",
            "iter_decider",
        )
        # Resolve every component up front, then summarise each one.
        resolved = [(label, getattr(model, label, None)) for label in tracked_names]

        def collect(component):
            found = []
            if component is None:
                return found
            try:
                for _, tensor in component.named_parameters(recurse=True):
                    found.append(tensor)
            except Exception:
                if hasattr(component, "parameters"):
                    for tensor in component.parameters():
                        found.append(tensor)
            return found

        def summarize(component):
            n_total = 0
            n_trainable = 0
            n_bytes = 0
            for tensor in collect(component):
                count = int(tensor.numel())
                n_total += count
                if not getattr(tensor, 'requires_grad', False):
                    continue
                n_trainable += count
                try:
                    n_bytes += count * int(tensor.element_size())
                except Exception:
                    pass
            return n_total, n_trainable, n_bytes

        grand_trainable = 0
        grand_bytes = 0
        report = []
        for label, component in resolved:
            n_total, n_trainable, n_bytes = summarize(component)
            grand_trainable += n_trainable
            grand_bytes += n_bytes
            size_mb = (n_bytes / bytes_per_mb) if n_bytes > 0 else 0.0
            report.append(
                f"  - {label}: trainable={n_trainable:,} params ({size_mb:.2f} MB) / total={n_total:,}"
            )

        grand_mb = (grand_bytes / bytes_per_mb) if grand_bytes > 0 else 0.0
        banner = (
            f"[InterleavedLoss] Phase switch -> phase={phase}, step={int(step)}, freeze_mode={mode}. "
            f"Total trainable params={grand_trainable:,} ({grand_mb:.2f} MB)"
        )
        logger.info("\n".join([banner] + report))

    # -------------------- Loss plumbing --------------------
    def _ensure_freeze_applied(self, **kwargs):
        model = kwargs.get("model", None)
        # Prefer externally-updated macro step if present
        global_step = kwargs.get("global_step", self._external_global_step)
        active_phase = self._get_phase(global_step)
        step = int(global_step) if global_step is not None else self._internal_step_counter
        self._apply_freeze_policy(model, active_phase, step)
        return active_phase

    # -------------------- External training state update --------------------
    def update_training_state(self, current_step: int, current_epoch: int):
        """Allow trainer/callback to provide macro training state (DDP-safe)."""
        try:
            self._external_global_step = int(current_step)
        except Exception:
            self._external_global_step = None
        try:
            self._external_epoch = int(current_epoch)
        except Exception:
            self._external_epoch = None

    def intra_iter_loss_func(self, *args, **kwargs):
        # Apply freeze policy as early as possible within forward
        active_phase = self._ensure_freeze_applied(**kwargs)

        if active_phase == "primary":
            if self.primary_loss._is_intra_iter_loss:
                return self.primary_loss.intra_iter_loss_func(*args, **kwargs)
            else:
                # Return zero tensor on appropriate device
                device = None
                for arg in args:
                    if hasattr(arg, 'device'):
                        device = arg.device
                        break
                return torch.tensor(0.0, device=device or 'cpu')
        else:
            if self.secondary_loss._is_intra_iter_loss:
                return self.secondary_loss.intra_iter_loss_func(*args, **kwargs)
            else:
                device = None
                for arg in args:
                    if hasattr(arg, 'device'):
                        device = arg.device
                        break
                return torch.tensor(0.0, device=device or 'cpu')

    def final_loss_func(
        self,
        logits: torch.Tensor,
        labels_shifted: torch.Tensor,
        iter_count: torch.Tensor,
        training: bool,
        **kwargs,
    ) -> torch.Tensor:
        # Apply freeze policy for this step before computing final loss
        active_phase = self._ensure_freeze_applied(**kwargs)

        if active_phase == "primary":
            return self.primary_loss.final_loss_func(
                logits=logits,
                labels_shifted=labels_shifted,
                iter_count=iter_count,
                training=training,
                **kwargs,
            )
        else:
            return self.secondary_loss.final_loss_func(
                logits=logits,
                labels_shifted=labels_shifted,
                iter_count=iter_count,
                training=training,
                **kwargs,
            )