import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import hashlib
import math
from typing import Callable, Dict, Optional, Any, Type, Union, Tuple
import pprint

from hr2r.model.hr2r_config import HR2RConfig
from hr2r.model.registry import (
    register_iter_decider,
    get_iter_decider_class,
    capture_init_args,
    mark_wrapper_iter_decider,
)
from hr2r.model.utils import class_string_to_type, type_to_dict_string, apply_noise
from hr2r.train.noise_scheduler import NoiseScheduler

try:
    from r2r.models.router import load_model
except Exception as _r2r_import_error:  # optional dependency: keep this module importable without it
    # ``except Exception`` (rather than a bare ``except``) lets KeyboardInterrupt and
    # SystemExit propagate while still tolerating any failure raised by the import.
    _R2R_IMPORT_FAILURE = repr(_r2r_import_error)

    def load_model(*args, **kwargs):
        """Stand-in for ``r2r.models.router.load_model`` when ``r2r`` is unavailable.

        Accepts any positional/keyword arguments (this module calls it as
        ``load_model(model_path=...)`` and unpacks two return values) and always
        raises, so a missing optional dependency is reported clearly at the
        point of use instead of surfacing as an unrelated ``TypeError``.

        Raises:
            ImportError: always; the message carries the original import failure.
        """
        raise ImportError(
            "r2r.models.router.load_model is unavailable "
            f"(import failed with {_R2R_IMPORT_FAILURE}); "
            "install `r2r` to load an R2R router."
        )

# Sentinel "decision logits" returned next to the boolean continue mask by the
# deciders in this module. They are finite stand-ins rather than real +/-inf.
# Not referenced in the visible part of this module; presumably the "certainly continue" counterpart.
POSITIVE_INFINITY_LOGITS = 10.0
# Returned by the deciders below once iter_depth >= max_iter (forced stop).
MINUS_INFINITY_LOGITS = -10.0
# Returned by the rule-based deciders that have no learned score; sigmoid(0.0) == 0.5.
NEUTRAL_LOGITS = 0.0

class IterDecider(nn.Module):
    """Base class for modules that decide whether a token keeps iterating.

    Implementations must handle inputs of arbitrary shape ``(..., vocab_size)``,
    where ``(...)`` is any number of leading dimensions (batch, sequence, ...),
    and must preserve those leading dimensions in their outputs.

    Attributes:
        threshold: 0-dim float32 buffer holding the decision threshold.
        max_iter: Iteration cap; the concrete deciders in this module stop every
            token once ``iter_depth >= max_iter``.
    """

    def __init__(self, threshold: float = 0.5, max_iter: int = 3):
        """Store the threshold (as a buffer) and the iteration cap.

        Args:
            threshold: Decision threshold, converted with ``float(...)``.
            max_iter: Maximum number of iterations.
        """
        super().__init__()
        # store as buffer to allow assignment on subclasses without property conflicts
        self.register_buffer("threshold", torch.tensor(float(threshold), dtype=torch.float32))
        self.max_iter = max_iter

    def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decide, per token, whether to continue iterating.

        Args:
            logits: Token logits of shape ``(..., vocab_size)``.
            iter_depth: Iteration depth that has already been processed for the token.
            **kwargs: Optional extra inputs. The deciders in this module read, for
                example, ``hidden_states``, ``all_hidden_states``,
                ``labels_shifted`` or ``iter_count_labels``.

        Returns:
            A ``(decision, decision_logits)`` pair, as returned by every concrete
            decider in this module: ``decision`` is a boolean tensor of shape
            ``(...)`` that is True where the token should continue, and
            ``decision_logits`` is a tensor of the same leading shape holding the
            score behind that decision (a sentinel constant for rule-based deciders).

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")


@register_iter_decider
@capture_init_args
class EntropyIterDecider(IterDecider):
    """Decider driven by the entropy of the softmax distribution over the vocabulary.

    The entropy is reduced over the last (vocabulary) axis, so inputs of shape
    ``(..., vocab_size)`` produce outputs of shape ``(...)``.

    Two modes are supported:

    * ``iter_ratios is None``: a token continues when its entropy exceeds
      ``entropy_threshold``.
    * ``iter_ratios`` given: at depth ``d`` the ``int(n * iter_ratios[d - 1])``
      highest-entropy positions along the last remaining axis (of size ``n``)
      continue; everything else stops.
    """

    def __init__(self, entropy_threshold: float = 0.5, max_iter: int = 3, iter_ratios: list = None):
        """
        Args:
            entropy_threshold: Entropy above which a token continues (threshold mode).
            max_iter: Iteration cap; at or beyond it every token stops.
            iter_ratios: Optional per-depth fractions of positions to keep iterating;
                must have ``max_iter - 1`` entries (checked in ``forward``).
        """
        super().__init__(max_iter=max_iter)
        self.iter_ratios = iter_ratios
        self.entropy_threshold = entropy_threshold

    def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> torch.Tensor:
        """Return ``(decision, decision_logits)`` for ``logits`` at ``iter_depth``.

        ``decision`` is a bool tensor of shape ``logits.shape[:-1]``. The second
        element is filled with ``MINUS_INFINITY_LOGITS`` when the cap forces a
        stop and with ``NEUTRAL_LOGITS`` otherwise.
        """
        # Hard cap: nothing continues and entropy is never computed.
        if iter_depth >= self.max_iter:
            forced_stop = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device)
            stop_logits = torch.full(
                forced_stop.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device
            )
            return forced_stop, stop_logits

        token_probs = F.softmax(logits, dim=-1)
        # The 1e-10 offset keeps log() finite for zero probabilities.
        token_entropy = -(token_probs * torch.log(token_probs + 1e-10)).sum(dim=-1)

        if self.iter_ratios is None:
            # Threshold mode.
            keep_going = token_entropy > self.entropy_threshold
        else:
            # Ratio mode: keep the top-k highest-entropy positions.
            assert len(self.iter_ratios) == self.max_iter - 1, "iter_ratios must be the same length as max_iter - 1"
            # iter_depth is 1-indexed while iter_ratios is 0-indexed.
            slot = iter_depth - 1
            if slot < len(self.iter_ratios):
                budget = int(token_entropy.size(-1) * self.iter_ratios[slot])
            else:
                budget = 0

            keep_going = torch.zeros_like(token_entropy, dtype=torch.bool, device=token_entropy.device)
            if budget > 0:
                _, top_positions = torch.topk(token_entropy, k=budget, dim=-1)
                keep_going.scatter_(-1, top_positions, True)

        neutral = torch.full(keep_going.shape, NEUTRAL_LOGITS, dtype=logits.dtype, device=logits.device)
        return keep_going, neutral


@register_iter_decider
@capture_init_args
class TrivialIterDecider(IterDecider):
    """Decider that never asks for another iteration.

    For logits of shape ``(..., vocab_size)`` it returns an all-False boolean
    mask of shape ``(...)``, regardless of ``iter_depth``.
    """

    def __init__(self, max_iter: int = 1):
        """
        Args:
            max_iter: Forwarded to the base class; not consulted by ``forward``.
        """
        super().__init__(max_iter=max_iter)

    def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> torch.Tensor:
        """Return an all-False mask plus ``NEUTRAL_LOGITS``-filled decision logits."""
        lead_shape = logits.shape[:-1]
        # new_zeros / new_full inherit the device (and, for new_full, the dtype) of ``logits``.
        never_continue = logits.new_zeros(lead_shape, dtype=torch.bool)
        neutral = logits.new_full(never_continue.shape, NEUTRAL_LOGITS)
        return never_continue, neutral

@register_iter_decider
@capture_init_args
class DynamicMismatchIterDecider(IterDecider):
    """[DEPRECATED] Use IterLabelDecider + DynamicMismatchIterLabelGenerator instead.

    Continues iterating wherever the greedy prediction (argmax of the logits)
    differs from the provided, already-shifted label. It needs ground-truth
    labels, so it is only meaningful during training.

    Args:
        max_iter: Maximum number of iterations (default: 3)
        ignore_index: Label value treated as "no label" (default: -100)
    """

    def __init__(self, max_iter: int = 3, ignore_index: int = -100):
        super().__init__(max_iter=max_iter)
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, iter_depth: int, labels_shifted: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """
        Decide whether to continue iterating based on prediction/label mismatch.

        Args:
            logits: The logits of the token, shape (..., vocab_size)
            iter_depth: Current iteration depth (1-indexed)
            labels_shifted: Ground-truth labels already shifted to line up with
                ``logits``, shape (...); optional

        Returns:
            ``(decision, decision_logits)``. ``decision`` is a boolean tensor that
            is all False when ``iter_depth >= max_iter`` or when no labels are
            given, and otherwise marks the mismatching positions.
            ``decision_logits`` is ``MINUS_INFINITY_LOGITS`` for the max-iter
            stop and ``NEUTRAL_LOGITS`` in every other case.
        """
        lead_shape = logits.shape[:-1]

        # Iteration cap reached: forced stop.
        if iter_depth >= self.max_iter:
            forced_stop = logits.new_zeros(lead_shape, dtype=torch.bool)
            return forced_stop, logits.new_full(forced_stop.shape, MINUS_INFINITY_LOGITS)

        # Without labels a mismatch cannot be determined, so stop iterating.
        if labels_shifted is None:
            no_labels = logits.new_zeros(lead_shape, dtype=torch.bool)
            return no_labels, logits.new_full(no_labels.shape, NEUTRAL_LOGITS)

        keep_going = self._compute_mismatch_continue(logits, labels_shifted)
        return keep_going, logits.new_full(keep_going.shape, NEUTRAL_LOGITS)

    def _compute_mismatch_continue(self, logits: torch.Tensor, labels_shifted: torch.Tensor) -> torch.Tensor:
        """Return the boolean "prediction != label" mask for ``logits``.

        Sequence inputs (more than one position) compare every position except
        the last against its label and force the last position to False.
        Single-position inputs compare directly and also mask out ``ignore_index``.
        """
        is_sequence = logits.dim() >= 2 and logits.shape[-2] > 1 and labels_shifted.shape[-1] > 1

        if not is_sequence:
            # Single token / edge case: direct comparison, skipping ignored labels.
            greedy_tokens = torch.argmax(logits, dim=-1)
            differs = greedy_tokens != labels_shifted
            has_label = labels_shifted != self.ignore_index
            return differs & has_label

        # Labels are already shifted, so position i of the logits is compared with
        # position i of the labels; only the final position is dropped.
        greedy_tokens = torch.argmax(logits[..., :-1, :], dim=-1)
        head_mismatch = greedy_tokens != labels_shifted[..., :-1]
        # NOTE(review): ignore_index positions are NOT filtered in this branch
        # (unlike the single-token branch), so an ignored label counts as a
        # mismatch whenever the argmax differs from it. Confirm this is intended.

        # Re-attach the dropped last position as False: it never continues.
        tail = torch.zeros_like(head_mismatch[..., :1])
        return torch.cat([head_mismatch, tail], dim=-1)

@register_iter_decider
@capture_init_args
class FixedLabelIterDecider(IterDecider):
    """[DEPRECATED] Use IterLabelDecider instead.

    Decider that replays externally supplied per-token iteration counts.

    The logits are used only for their shape, dtype and device; the continue
    mask comes from the labels alone:

        continue = (iter_count_labels > iter_depth) and (iter_count_labels != -100)

    Notes:
    - The returned mask takes its shape from ``iter_count_labels``, which is
      expected to match the leading dimensions of ``logits``.
    - ``-100`` marks padding positions, which always stop.
    - When ``iter_depth >= max_iter`` every token stops.
    - ``label_type`` is stored on the instance but is not read by ``forward``.
    """

    def __init__(self, max_iter: int = 3, label_type: str = "mismatch"):
        super().__init__(max_iter=max_iter)
        self.label_type = label_type

    def forward(
        self,
        logits: torch.Tensor,
        iter_depth: int,
        iter_count_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return ``(decision, decision_logits)`` derived from ``iter_count_labels``."""
        lead_shape = logits.shape[:-1]

        # Cap reached: forced stop for every token.
        if iter_depth >= self.max_iter:
            forced_stop = logits.new_zeros(lead_shape, dtype=torch.bool)
            return forced_stop, logits.new_full(forced_stop.shape, MINUS_INFINITY_LOGITS)

        # No labels supplied: nothing to follow, so stop.
        if iter_count_labels is None:
            no_labels = logits.new_zeros(lead_shape, dtype=torch.bool)
            return no_labels, logits.new_full(no_labels.shape, NEUTRAL_LOGITS)

        # Continue only on real (non-padding) tokens whose target count is still ahead.
        is_real_token = iter_count_labels != -100
        wants_more = iter_count_labels > iter_depth
        keep_going = wants_more & is_real_token
        return keep_going, logits.new_full(keep_going.shape, NEUTRAL_LOGITS)


@register_iter_decider
@capture_init_args
class IterLabelDecider(IterDecider):
    """Decider that strictly follows the provided ``iter_count_labels``.

    A token continues if and only if its label is greater than ``iter_depth``
    and is not the padding/ignore value ``-100``.
    """

    def __init__(self, max_iter: int = 3):
        super().__init__(max_iter=max_iter)

    def forward(
        self,
        logits: torch.Tensor,
        iter_depth: int,
        iter_count_labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return ``(decision, decision_logits)``.

        ``decision`` is all False at the iteration cap (with
        ``MINUS_INFINITY_LOGITS``) or when no labels are given (with
        ``NEUTRAL_LOGITS``); otherwise it follows the labels and takes their shape.
        """
        lead_shape = logits.shape[:-1]

        if iter_depth >= self.max_iter:
            forced_stop = logits.new_zeros(lead_shape, dtype=torch.bool)
            return forced_stop, logits.new_full(forced_stop.shape, MINUS_INFINITY_LOGITS)

        if iter_count_labels is None:
            no_labels = logits.new_zeros(lead_shape, dtype=torch.bool)
            return no_labels, logits.new_full(no_labels.shape, NEUTRAL_LOGITS)

        not_padding = iter_count_labels != -100
        target_ahead = iter_count_labels > iter_depth
        keep_going = target_ahead & not_padding
        return keep_going, logits.new_full(keep_going.shape, NEUTRAL_LOGITS)

class ClassifierBlock(nn.Module):
    """
    One pre-norm residual MLP block of the classifier backbone.

    Computes ``dim_change(x) + mlp(layer_norm(x))`` where the MLP expands the
    width by ``expansion_factor`` before projecting to ``output_dim``.
    """
    def __init__(self, input_dim, output_dim, expansion_factor=4, dropout_rate=0.3):
        """
        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Size of the last dimension of the output.
            expansion_factor: Multiplier for the hidden width of the MLP.
            dropout_rate: Dropout probability used after the activation and
                after the output projection.
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        inner_dim = input_dim * expansion_factor

        self.layer_norm = nn.LayerNorm(input_dim)

        feed_forward = [
            nn.Linear(input_dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(inner_dim, output_dim),
            nn.Dropout(dropout_rate),
        ]
        self.mlp = nn.Sequential(*feed_forward)

        # The skip path needs a projection only when the width changes.
        if input_dim == output_dim:
            self.dim_change = nn.Identity()
        else:
            self.dim_change = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the block to ``x`` (last dimension ``input_dim``)."""
        skip = self.dim_change(x)
        update = self.mlp(self.layer_norm(x))
        return skip + update


class ClassifierBackbone(nn.Module):
    """
    Backbone architecture for all classifiers.

    Pipeline: input projection -> one ``ClassifierBlock`` per entry of
    ``hidden_dims`` -> LayerNorm + Linear output head. Position embeddings are
    disabled in this setup (``apply_position_embedding`` is a no-op).
    """

    def __init__(
        self,
        input_dim,
        output_dim=1,
        hidden_dims=(256, 512, 256),
        expansion_factor=4,
        dropout_rate=0.3,
        use_position_embedding=False,
        max_position_embeddings=1024,
    ):
        """
        Args:
            input_dim: Size of the last dimension of the input features.
            output_dim: Size of the last dimension of the output.
            hidden_dims: Non-empty sequence of block widths. Any sequence
                (list, tuple, ...) is accepted; it is copied, never mutated.
            expansion_factor: Forwarded to every ``ClassifierBlock``.
            dropout_rate: Forwarded to every ``ClassifierBlock``.
            use_position_embedding: Stored on the instance; it has no effect
                because ``apply_position_embedding`` returns its input unchanged.
            max_position_embeddings: Accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``hidden_dims`` is empty.
        """
        super().__init__()
        self.use_position_embedding = use_position_embedding

        # Copy into a list so tuples / config sequences work and the caller's
        # object (or a shared default) is never aliased.
        widths = list(hidden_dims)
        if not widths:
            raise ValueError("hidden_dims must contain at least one width")

        # NOTE: attribute creation order is kept (blocks, input_projection,
        # output_layer) because it fixes the order in which _init_weights
        # consumes random numbers under a fixed seed.
        self.blocks = nn.ModuleList()

        self.input_projection = nn.Linear(input_dim, widths[0])

        # Block i maps widths[i] -> widths[i + 1]; the last width is repeated so
        # the final block keeps its width.
        for block_input_dim, block_output_dim in zip(widths, widths[1:] + widths[-1:]):
            self.blocks.append(
                ClassifierBlock(
                    input_dim=block_input_dim,
                    output_dim=block_output_dim,
                    expansion_factor=expansion_factor,
                    dropout_rate=dropout_rate,
                )
            )

        self.output_layer = nn.Sequential(
            nn.LayerNorm(widths[-1]),
            nn.Linear(widths[-1], output_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every ``nn.Linear``: Kaiming-normal weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def apply_position_embedding(self, x, position_ids=None):
        """Return ``x`` unchanged; no position embeddings are used in this setup."""
        return x

    def forward(self, x, position_ids=None):
        """Map ``x`` (last dimension ``input_dim``) to last dimension ``output_dim``."""
        x = self.input_projection(x)
        x = self.apply_position_embedding(x, position_ids)
        for block in self.blocks:
            x = block(x)
        return self.output_layer(x)

@register_iter_decider
@capture_init_args
class MLPIterDecider(IterDecider):
    """Classifier-based iteration decider using hidden states and top-k logits.

    Features are the concatenation of the selected layers' hidden states and a
    linear projection of the top-k logits. They pass through
    ``combined_projection`` and a ``ClassifierBackbone`` that emits one decision
    logit per token. A token continues when ``sigmoid(logit) > threshold``.
    """

    def __init__(
        self,
        topk: int = 100,
        hidden_states_size: int = 1024,
        hidden_states_layer_nums: list = [16,20,24,28], # explicit layer indices to use from all_hidden_states
        hidden_dims: list = [256, 512, 256],
        expansion_factor: int = 4,
        dropout_rate: float = 0.3,
        normalize_input: bool = False,
        threshold: float = 0.5,
        max_iter: int = 3,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            topk: Number of largest logits fed to the logits projection.
            hidden_states_size: Width of a single layer's hidden state.
            hidden_states_layer_nums: Indices into the layer axis of
                ``all_hidden_states``; an empty list selects the last layer.
                The list is copied, so the shared default is never mutated.
            hidden_dims: Block widths for the ``ClassifierBackbone``.
            expansion_factor: Forwarded to the backbone.
            dropout_rate: Forwarded to the backbone.
            normalize_input: If True, layer-normalise the concatenated hidden
                states and softmax the top-k logits before projecting them.
            threshold: Initial value of the (trainable) decision threshold.
            max_iter: Iteration cap; at or beyond it every token stops.
            dtype: dtype of the threshold parameter and the two projections.
        """
        super().__init__(max_iter=max_iter)  # also sets self.max_iter
        self.topk = topk
        self.hidden_states_size = hidden_states_size
        self.hidden_states_layer_nums = list(hidden_states_layer_nums)
        # Only relevant if a class in the hierarchy defines ``threshold`` at
        # class level; the buffer registered by the base class lives on the
        # instance and is replaced by the Parameter assignment below.
        if hasattr(self.__class__, 'threshold'):
            delattr(self, 'threshold')
        self.threshold = nn.Parameter(torch.tensor(threshold, dtype=dtype, requires_grad=True))

        num_selected = max(1, len(self.hidden_states_layer_nums))

        self.normalize_input = normalize_input
        if self.normalize_input:
            self.layer_norm_hidden_states = nn.LayerNorm(hidden_states_size * num_selected)

        # Project top-k logits to hidden state size
        self.logits_projection = nn.Linear(self.topk, hidden_states_size, dtype=dtype)

        # Combine hidden states and projected logits
        combined_size = hidden_states_size * num_selected + hidden_states_size
        self.combined_projection = nn.Linear(combined_size, hidden_dims[0], dtype=dtype)

        # Backbone MLP stack
        # NOTE(review): unlike the two projections above, the backbone (and the
        # optional LayerNorm) are created in the default dtype; presumably the
        # caller casts the whole module or runs under autocast. Confirm.
        self.backbone = ClassifierBackbone(
            input_dim=hidden_dims[0],
            output_dim=1,
            hidden_dims=hidden_dims,
            expansion_factor=expansion_factor,
            dropout_rate=dropout_rate,
        )

        self.sigmoid = nn.Sigmoid()

    def _select_hidden_features(self, logits: torch.Tensor, all_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
        """Concatenate the requested layers' hidden states along the last axis.

        Returns a tensor whose leading shape is ``logits.shape[:-1]``. When
        ``all_hidden_states`` is None, zeros of the expected width are returned.

        Raises:
            ValueError: if a requested layer index is outside ``[0, total_layers)``.
        """
        lead_shape = logits.shape[:-1]

        if all_hidden_states is None:
            # Fallback to zeros if hidden states are unavailable
            num_selected = max(1, len(self.hidden_states_layer_nums))
            return torch.zeros(
                *lead_shape,
                self.hidden_states_size * num_selected,
                device=logits.device,
                dtype=logits.dtype,
            )

        hs = all_hidden_states
        # Expect shape (..., L, H); if (..., H) provided, treat as single-layer
        if hs.dim() == logits.dim():
            hs = hs.unsqueeze(-2)
        total_layers = hs.size(-2)

        # An empty selection means "use the last layer".
        indices = self.hidden_states_layer_nums or [total_layers - 1]

        # Validate the plain Python indices on the host: doing this on a device
        # tensor with .item() would force a host/device sync on every call.
        if min(indices) < 0 or max(indices) >= total_layers:
            raise ValueError(f"hidden_states_layer_nums out of range: {indices}, total_layers={total_layers}")

        index_tensor = torch.as_tensor(indices, device=hs.device, dtype=torch.long)
        selected = torch.index_select(hs, dim=-2, index=index_tensor)  # (..., num_selected, H)
        return selected.reshape(*lead_shape, selected.size(-2) * self.hidden_states_size)

    def forward(self, logits: torch.Tensor, iter_depth: int, all_hidden_states: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
        """Return ``(decision_mask, decision_logits)`` for the given tokens.

        Args:
            logits: Token logits, shape ``(..., vocab_size)``.
            iter_depth: Current iteration depth.
            all_hidden_states: Hidden states shaped ``(..., L, H)``, or
                ``(..., H)`` for a single layer; zeros are used if None.

        Returns:
            ``decision_mask`` (bool, shape ``(...)``) and the raw backbone logits.
            At or beyond ``max_iter`` the mask is all False and the logits are
            ``MINUS_INFINITY_LOGITS``.

        Raises:
            ValueError: if a configured layer index is out of range.
        """
        if iter_depth >= self.max_iter:
            decision = torch.zeros(
                logits.shape[:-1], dtype=torch.bool, device=logits.device
            )
            logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device)
            return decision, logits_out

        hidden_concat = self._select_hidden_features(logits, all_hidden_states)

        # Mirror PluginNeuralIterDecider behavior: apply top-k on logits
        k = min(self.topk, logits.size(-1))
        topk_values, _ = torch.topk(logits, k=k, dim=-1)

        # Optional normalization
        if self.normalize_input:
            hidden_concat = self.layer_norm_hidden_states(hidden_concat)
            topk_values = torch.softmax(topk_values, dim=-1)

        # Project logits and combine
        logits_features = self.logits_projection(topk_values)
        combined_features = torch.cat([hidden_concat, logits_features], dim=-1)
        x = self.combined_projection(combined_features)

        decision_logits = self.backbone(x)
        # The backbone emits a trailing singleton dimension; drop it.
        if decision_logits.dim() == logits.dim():
            decision_logits = decision_logits.squeeze(-1)

        decision_scores = self.sigmoid(decision_logits)
        # Compare against a plain float so the threshold's dtype/device cannot
        # affect the comparison; the detach keeps it out of the autograd graph.
        thr = self.threshold
        if isinstance(thr, torch.Tensor):
            thr = float(thr.detach().item())
        decision_mask = (decision_scores > thr)
        return decision_mask, decision_logits


@register_iter_decider
@capture_init_args
class RandomIterDecider(IterDecider):
    """Random iteration decider following a target distribution of final depths.

    ``final_probs`` gives the probability that a token stops at each
    (1-indexed) iteration. At every depth a fresh uniform random number is drawn
    per token with ``torch.rand`` and compared with the conditional stop
    probability for that depth. Decisions are therefore only reproducible
    through the global torch RNG state, not through the content of ``logits``.
    """

    def __init__(self, final_probs: list = None, max_iter: int = 3, dtype: torch.dtype = torch.bfloat16):
        """
        Initialize the random iteration decider.

        Args:
            final_probs: Probabilities for stopping at each final iteration count (1-indexed).
                        [0.2, 0.3, 0.5] means 20% stop at iter 1, 30% at iter 2, 50% at iter 3.
                        The values are normalised to sum to 1. Defaults to a
                        uniform distribution over ``max_iter`` depths.
            max_iter: Maximum iterations (also the length of the default ``final_probs``)
            dtype: dtype used to store the probability tensors.
        """
        super().__init__(max_iter=max_iter)

        if final_probs is None:
            final_probs = [1.0 / max_iter] * max_iter

        # Warning if final_probs exceeds max_iter
        if len(final_probs) > max_iter:
            print(
                f"Warning: final_probs length ({len(final_probs)}) exceeds max_iter ({max_iter})"
            )

        # Normalize and store
        total = sum(final_probs)
        self.final_probs = torch.tensor(
            [p / total for p in final_probs], dtype=dtype
        )
        self.cumulative_probs = torch.cumsum(self.final_probs, dim=0)

        # Precompute conditional stop probabilities P(stop at d | survived < d)
        # For d = 1: cond_stop = final_probs[0]
        # For d > 1: cond_stop = final_probs[d-1] / (1 - cumulative_probs[d-2])
        conditional = []
        for idx, fp in enumerate(self.final_probs):
            if idx == 0:
                conditional.append(fp)
            else:
                remaining = 1.0 - self.cumulative_probs[idx - 1]
                # Protect against numerical issues when remaining is zero
                conditional.append(
                    fp / remaining if remaining > 0 else torch.tensor(1.0)
                )

        # Store as buffer so it moves with model to device
        self.register_buffer(
            "conditional_stop_probs", torch.tensor(conditional, dtype=dtype)
        )

    def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> torch.Tensor:
        """Conditional random stopping based on target final_probs.

        At iteration depth d (1-indexed):
            stop_prob = P(stop at d | survived < d) = final_probs[d-1] / (1 - sum_{i<d} final_probs[i])

        A token continues to the next iteration if a freshly-drawn uniform random
        number is greater than or equal to stop_prob *and* ``max_iter`` has not
        been reached.

        Returns:
            ``(decision, decision_logits)``: a bool mask of shape
            ``logits.shape[:-1]`` and a tensor of the same shape filled with
            ``MINUS_INFINITY_LOGITS`` (max-iter stop) or ``NEUTRAL_LOGITS``.
        """
        dtype = logits.dtype
        lead_shape = logits.shape[:-1]

        # If we've already reached or exceeded the maximum allowed iterations, stop.
        if iter_depth >= self.max_iter:
            decision = torch.zeros(lead_shape, dtype=torch.bool, device=logits.device)
            logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device)
            return decision, logits_out

        slot = iter_depth - 1
        # final_probs may be shorter than max_iter. Past its end the whole
        # probability mass has already been spent, so every token stops;
        # indexing the buffer there would raise IndexError.
        if slot >= self.conditional_stop_probs.numel():
            decision = torch.zeros(lead_shape, dtype=torch.bool, device=logits.device)
            logits_out = torch.full(decision.shape, NEUTRAL_LOGITS, dtype=logits.dtype, device=logits.device)
            return decision, logits_out

        # Lookup pre-computed conditional stop probability for this iteration depth
        stop_prob = self.conditional_stop_probs[slot].to(logits.device)

        # Draw uniform random numbers for each token (shape matches leading dims)
        random_values = torch.rand(lead_shape, device=logits.device, dtype=dtype)

        # Continue if random value is >= stop_prob
        decision = random_values >= stop_prob
        logits_out = torch.full(decision.shape, NEUTRAL_LOGITS, dtype=logits.dtype, device=logits.device)
        return decision, logits_out


@register_iter_decider
@capture_init_args
class PluginNeuralIterDecider(IterDecider):
    """Generic neural iteration decider that *plugs in* an externally
    defined :class:`torch.nn.Module`, given either as a class/factory or as the
    dotted ``__module__.__name__`` string of the class, like
    r2r.models.router.HiddenStatesLogitsClassifier.

    The wrapped module must return a :class:`torch.Tensor` of decision logits.
    The output is cast to the dtype of ``logits``, passed through a sigmoid and
    compared with ``threshold``; values above the threshold mean *continue*.
    A trailing singleton dimension is squeezed away.

    The mapping from this class' inputs to the keyword arguments expected by
    ``module.forward`` is defined via ``input_mapping``. Available inputs are
    ``logits``, ``iter_depth``, ``hidden_states`` and any extra keyword argument
    passed to :meth:`forward`. For example, if the wrapped module defines

    ```python
    def forward(self, x, depth):
        ...
    ```

    then ``input_mapping`` should be ``{"logits": "x", "iter_depth": "depth"}``.

    Keys of ``input_mapping`` that are not among the available inputs are
    ignored, and the wrapped module may consume only a subset of the inputs.
    """

    def __init__(
        self,
        module_cls: Union[Callable[[Dict[str, Any]], nn.Module], str],
        init_kwargs: Dict[str, Any],
        input_mapping: Dict[str, str],
        threshold: float = 0.5,
        max_iter: int = 3,
        topk_logits: Optional[int] = None,
        dtype: Union[torch.dtype, str] = torch.bfloat16,
    ):
        """
        Args:
            module_cls: Class/factory of the wrapped module, or its dotted class path.
            init_kwargs: Keyword arguments used to build the wrapped module.
            input_mapping: Maps input names of this decider to keyword names of
                the wrapped module's ``forward``; ``None`` is treated as empty.
            threshold: Initial value of the (trainable) decision threshold.
            max_iter: Iteration cap; at or beyond it every token stops.
            topk_logits: If set, only the ``topk_logits`` largest logits are
                passed to the wrapped module.
            dtype: dtype for the wrapped module and the threshold, as a
                ``torch.dtype`` or a name such as ``"bfloat16"`` / ``"torch.bfloat16"``.

        Raises:
            ValueError: if ``dtype`` is a string that does not name a torch dtype.
        """
        super().__init__(max_iter=max_iter)  # also sets self.max_iter
        if isinstance(module_cls, str):
            module_cls = class_string_to_type(module_cls)
        dtype = self._resolve_dtype(dtype)
        self.plugin_module = module_cls(**init_kwargs).to(dtype=dtype)
        self.input_mapping = input_mapping or {}
        # Only relevant if a class in the hierarchy defines ``threshold`` at
        # class level; the buffer registered by the base class lives on the
        # instance and is replaced by the Parameter assignment below.
        if hasattr(self.__class__, 'threshold'):
            delattr(self, 'threshold')
        self.threshold = nn.Parameter(torch.tensor(threshold, dtype=dtype, requires_grad=True))
        self.topk_logits = topk_logits

    @staticmethod
    def _resolve_dtype(dtype):
        """Turn a dtype name (``"bfloat16"`` or ``"torch.bfloat16"``) into a ``torch.dtype``.

        Non-string values are returned unchanged.
        """
        if not isinstance(dtype, str):
            return dtype
        resolved = getattr(torch, dtype.rsplit(".", 1)[-1], None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unknown torch dtype name: {dtype!r}")
        return resolved

    def forward(self, logits: torch.Tensor, iter_depth: int, hidden_states: torch.Tensor = None, **kwargs) -> torch.Tensor:
        """Delegate the decision to the plugged neural ``module``.

        Returns:
            ``(decision, decision_logits)``: a bool mask and the wrapped module's
            output cast to ``logits.dtype``. At or beyond ``max_iter`` the mask
            is all False and the logits are ``MINUS_INFINITY_LOGITS``.

        Raises:
            TypeError: if the wrapped module does not return a ``torch.Tensor``.
        """

        # Early-exit if maximum allowed iterations has been reached.
        if iter_depth >= self.max_iter:
            decision = torch.zeros(
                logits.shape[:-1], dtype=torch.bool, device=logits.device
            )
            logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device)
            return decision, logits_out

        # Everything the wrapped module may ask for. An explicit dict (instead of
        # locals()) also exposes inputs that arrive through **kwargs.
        available = {
            "logits": logits,
            "iter_depth": iter_depth,
            "hidden_states": hidden_states,
            **kwargs,
        }

        # Apply topk transformation to logits (only needed when they are forwarded).
        if self.topk_logits is not None and "logits" in self.input_mapping:
            topk_values, _ = torch.topk(logits, k=self.topk_logits, dim=-1)
            available["logits"] = topk_values

        # Build the kwargs dict for the wrapped module according to *input_mapping*;
        # mapping keys without a matching input are skipped.
        module_forward_kwargs = {
            target: available[source]
            for source, target in self.input_mapping.items()
            if source in available
        }

        # Forward pass through the wrapped module.
        output = self.plugin_module(**module_forward_kwargs)

        if not isinstance(output, torch.Tensor):
            raise TypeError(
                "PluginNeuralIterDecider's module must return a torch.Tensor, "
                f"but got {type(output)} instead."
            )

        # Drop a trailing singleton dimension so the result has shape (...).
        logits_out = output
        if logits_out.dim() == logits.dim() and logits_out.shape[-1] == 1:
            logits_out = logits_out.squeeze(-1)
        logits_out = logits_out.to(dtype=logits.dtype)

        # binary decision using threshold
        thr = self.threshold
        if isinstance(thr, torch.Tensor):
            thr = float(thr.detach().item())
        decision = torch.sigmoid(logits_out) > thr
        return decision, logits_out

    def R2R_update_model_and_config(self, R2R_router_path: str, config: HR2RConfig) -> HR2RConfig:
        """Load the R2R router from the given path and record the R2R router configs.

        The loaded router replaces ``plugin_module`` and is moved to the dtype and
        device of the current module's first parameter. ``config`` is updated in
        place (``iter_decider_kwargs["init_kwargs"]``) and returned.
        """

        with torch.no_grad():
            first_param = next(self.plugin_module.parameters())
            dtype, device = first_param.dtype, first_param.device
            R2R_router, R2R_router_configs = load_model(model_path=R2R_router_path)
            self.plugin_module = R2R_router.to(dtype=dtype, device=device)
            print("\n[PluginNeuralIterDecider] Updating plugin iter decider configs from saved version.")
            print("[PluginNeuralIterDecider] Original config:")
            pprint.pprint(config.iter_decider_kwargs.get("init_kwargs", {}), indent=2, compact=False)
            config.iter_decider_kwargs["init_kwargs"] = R2R_router_configs['init_args']
            print("[PluginNeuralIterDecider] Updated config:")
            pprint.pprint(config.iter_decider_kwargs.get("init_kwargs", {}), indent=2, compact=False)
            # Keep the captured constructor arguments in sync, when present.
            if hasattr(self, "_init_args"):
                self._init_args = R2R_router_configs['init_args']
            return config

@register_iter_decider
@capture_init_args
class SmoothTransitionIterDecider(IterDecider):
    """Iteration decider that transitions from an initial decider to a final decider.

    On every forward call (the code does not check ``self.training``), each token
    position uses the initial decider's decision with probability ``w`` and the final
    decider's decision with probability ``1 - w``. ``w`` comes from
    :meth:`get_transition_weight` and follows the configured schedule.

    Args:
        max_iter: Maximum number of iterations (default: 3). Also used as the default
            ``max_iter`` of both nested deciders.
        initial_iter_decider_cls: Class name or class of the initial iter_decider
        final_iter_decider_cls: Class name or class of the final iter_decider
        initial_iter_decider_kwargs: Keyword arguments for the initial iter_decider
        final_iter_decider_kwargs: Keyword arguments for the final iter_decider
        scheduler_type: Type of scheduling ("linear", "constant", "exponential", "step")
        num_epochs: Total number of epochs; only read by the "step" scheduler
        num_grow_steps: Total number of steps for the weight to change from 1 to 0
            ("linear" and "exponential" schedulers). A value of 0 (or less) is treated
            as an already-finished transition.
        exp_gamma: Exponential decay factor for the "exponential" scheduler
        continue_logits_strategy: How the nested deciders' logits are combined:
            "select" (token-wise pick following the decision mask), "add" (sum) or
            "average" (logit of the w-weighted mean of the two sigmoid probabilities).

    Raises:
        ValueError: If ``continue_logits_strategy`` or ``scheduler_type`` is unknown.
    """

    def __init__(
        self,
        max_iter: int = 3,
        initial_iter_decider_cls: Union[str, Type[IterDecider]] = "TrivialIterDecider",
        final_iter_decider_cls: Union[str, Type[IterDecider]] = "EntropyIterDecider",
        initial_iter_decider_kwargs: dict = None,
        final_iter_decider_kwargs: dict = None,
        scheduler_type: str = "linear",
        num_epochs: Optional[int] = None,
        num_grow_steps: int = 1000,
        exp_gamma: float = 2.0,
        continue_logits_strategy: str = "average",
        ):
        """
        Initialize the smooth transition iteration decider.
        """
        super().__init__(max_iter=max_iter)
        self.scheduler_type = scheduler_type.lower()
        self.num_epochs = num_epochs
        self.num_grow_steps = num_grow_steps
        self.exp_gamma = exp_gamma

        # Logits strategy for combining nested deciders: "select", "add" or "average"
        allowed_strategies = {"select", "add", "average"}
        if continue_logits_strategy not in allowed_strategies:
            raise ValueError(f"continue_logits_strategy must be one of {allowed_strategies}, got {continue_logits_strategy}")
        self.continue_logits_strategy = continue_logits_strategy

        # Validate the scheduler before building the nested deciders so a bad config fails fast
        valid_types = ["linear", "constant", "exponential", "step"]
        if self.scheduler_type not in valid_types:
            raise ValueError(f"scheduler_type must be one of {valid_types}, got {self.scheduler_type}")

        # Work on copies: setdefault below must not leak `max_iter` into the caller's dicts
        initial_iter_decider_kwargs = dict(initial_iter_decider_kwargs or {})
        final_iter_decider_kwargs = dict(final_iter_decider_kwargs or {})

        # Ensure max_iter is consistent across both deciders unless explicitly overridden
        initial_iter_decider_kwargs.setdefault('max_iter', max_iter)
        final_iter_decider_kwargs.setdefault('max_iter', max_iter)

        # Resolve registry names to classes and create the instances
        if isinstance(initial_iter_decider_cls, str):
            initial_iter_decider_cls = get_iter_decider_class(initial_iter_decider_cls)
        if isinstance(final_iter_decider_cls, str):
            final_iter_decider_cls = get_iter_decider_class(final_iter_decider_cls)

        self.initial_iter_decider = initial_iter_decider_cls(**initial_iter_decider_kwargs)
        self.final_iter_decider = final_iter_decider_cls(**final_iter_decider_kwargs)

        # Training state. `transition_weight` is only a cache refreshed by
        # update_training_state(); forward() always recomputes the weight.
        self.current_step = 0
        self.current_epoch = 0
        self.transition_weight = 0.0

    def _step_progress(self) -> float:
        """Return the step-based progress, capped at 1.0, guarding against a zero-length ramp."""
        if not self.num_grow_steps or self.num_grow_steps <= 0:
            # No ramp configured: the transition is considered complete
            return 1.0
        return min(self.current_step / self.num_grow_steps, 1.0)

    def get_transition_weight(self) -> float:
        """
        Get current transition weight w (probability of using initial decider).
        Returns value between 0 and 1, where 1 means use initial decider fully,
        0 means use final decider fully.

        Raises:
            ValueError: For the "step" scheduler when ``num_epochs`` is None.
        """
        if self.scheduler_type == "linear":
            # Linear decrease from 1 to 0
            return 1.0 - self._step_progress()

        if self.scheduler_type == "exponential":
            # Exponential decay: weight = exp(-gamma * progress)
            return math.exp(-self.exp_gamma * self._step_progress())

        if self.scheduler_type == "step":
            if self.num_epochs is None:
                raise ValueError("num_epochs must be provided for step scheduler")
            # Step-wise decrease: weight decreases linearly with epoch
            epoch_progress = min(self.current_epoch / max(self.num_epochs - 1, 1), 1.0)
            return 1.0 - epoch_progress

        # "constant" (and any unexpected value) always uses a 50/50 mix
        return 0.5

    def update_training_state(self, current_step: int, current_epoch: int):
        """Update the current training state for weight scheduling."""
        self.current_step = current_step
        self.current_epoch = current_epoch
        self.transition_weight = self.get_transition_weight()

    def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decide whether to continue iterating using smooth transition between two deciders.

        Args:
            logits: The logits of the token, shape (..., vocab_size)
            iter_depth: The iteration depth of the token that has been processed.
            **kwargs: Additional arguments passed to both deciders

        Returns:
            A ``(decision, continue_logits)`` tuple, both of shape ``logits.shape[:-1]``.
            ``decision`` is selected token-wise from the two nested decisions;
            ``continue_logits`` combines the nested logits according to
            ``continue_logits_strategy`` and is cast to ``logits.dtype``.

        Raises:
            TypeError: If a nested decider does not return a tuple of at least two items.
            ValueError: If ``continue_logits_strategy`` holds an unknown value.
        """
        # If we've reached max iterations, stop
        if iter_depth >= self.max_iter:
            decision = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device)
            logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device)
            return decision, logits_out

        # Always recompute so direct edits to current_step/current_epoch are honoured
        w = self.get_transition_weight()

        # Both deciders are always evaluated; the mix happens afterwards
        initial_output = self.initial_iter_decider(logits, iter_depth, **kwargs)
        final_output = self.final_iter_decider(logits, iter_depth, **kwargs)

        # Strictly require (decision_bool, logits) tuples from nested deciders
        if (not isinstance(initial_output, tuple)) or (len(initial_output) < 2):
            raise TypeError("Initial iter decider must return a (decision_bool, logits) tuple")
        if (not isinstance(final_output, tuple)) or (len(final_output) < 2):
            raise TypeError("Final iter decider must return a (decision_bool, logits) tuple")

        initial_decision, initial_logits = initial_output[0], initial_output[1]
        final_decision, final_logits = final_output[0], final_output[1]

        # Token-level random mask: True where the initial decider is used
        random_mask = torch.rand(logits.shape[:-1], device=logits.device, dtype=torch.float32)
        use_initial = random_mask < w

        decision = torch.where(use_initial, initial_decision, final_decision)

        x = initial_logits.to(dtype=logits.dtype)
        y = final_logits.to(dtype=logits.dtype)
        if self.continue_logits_strategy == "add":
            logits_sel = x + y
        elif self.continue_logits_strategy == "select":
            logits_sel = torch.where(use_initial, x, y)
        elif self.continue_logits_strategy == "average":
            # z = logit(sigmoid(x) * w + sigmoid(y) * (1 - w)); a 0-dim w broadcasts over x/y
            w_tensor = torch.as_tensor(w, device=logits.device, dtype=logits.dtype)
            p = torch.sigmoid(x) * w_tensor + torch.sigmoid(y) * (1.0 - w_tensor)
            # eps clamps p away from 0/1 so the logit stays finite
            logits_sel = torch.logit(p, eps=torch.finfo(p.dtype).eps)
        else:
            raise ValueError(f"Invalid continue_logits_strategy: {self.continue_logits_strategy}")
        return decision, logits_sel


@register_iter_decider
@capture_init_args
class InterleavedIterDecider(IterDecider):
    """Switch between two iter deciders on fixed intervals to match InterleavedLoss phases.

    Phases:
    - primary phase: use `primary_iter_decider`
    - secondary phase: use `secondary_iter_decider`

    Configuration mirrors hr2r.model.loss.InterleavedLoss so both can be kept in sync by
    the existing LoggerCallback via update_training_state hooks.

    Args:
        max_iter: Iteration cap; also the default ``max_iter`` of both nested deciders.
        primary_iter_decider: Registry class name of the primary decider (must be a string).
        primary_iter_decider_kwargs: Constructor kwargs for the primary decider.
        secondary_iter_decider: Either a registry class name, or a reference string
            ("primary_iter_decider", "self.<attr path>" or "...<attr path>") that resolves
            to an IterDecider already attached to this instance, which is then reused
            as-is so both phases share the same module.
        secondary_iter_decider_kwargs: Constructor kwargs for the secondary decider.
            String values using the same reference syntax are rewritten to a
            ``<name>_cls`` / ``<name>_kwargs`` pair and, after construction, the
            matching attribute is replaced with the referenced instance.
        interval_steps: Default length (in steps) of each phase.
        primary_interval_steps: Length of the primary phase; None or 0 falls back to
            ``interval_steps``.
        secondary_interval_steps: Length of the secondary phase; None or 0 falls back to
            ``interval_steps``.
        start_with: Phase that opens each cycle; any value other than "primary" or
            "secondary" is silently replaced by "primary".
        eval_use: Decider used when the module is not in training mode; any value other
            than "primary" or "secondary" is silently replaced by "secondary".

    Raises:
        ValueError: If either decider name is not a string.
    """

    def __init__(
        self,
        max_iter: int = 3,
        primary_iter_decider: str = "IterLabelDecider",
        primary_iter_decider_kwargs: Optional[dict] = None,
        secondary_iter_decider: str = "EntropyIterDecider",
        secondary_iter_decider_kwargs: Optional[dict] = None,
        interval_steps: int = 100,
        primary_interval_steps: Optional[int] = None,
        secondary_interval_steps: Optional[int] = None,
        start_with: str = "primary",
        eval_use: str = "secondary",
    ):
        """Build both nested deciders and store the interval configuration."""
        super().__init__(max_iter=max_iter)

        # Instantiate nested deciders (copies are taken so the caller's dicts are not mutated)
        primary_iter_decider_kwargs = dict(primary_iter_decider_kwargs or {})
        secondary_iter_decider_kwargs = dict(secondary_iter_decider_kwargs or {})
        primary_iter_decider_kwargs.setdefault("max_iter", max_iter)
        secondary_iter_decider_kwargs.setdefault("max_iter", max_iter)

        if not isinstance(primary_iter_decider, str):
            raise ValueError("primary_iter_decider must be a string class name")
        if not isinstance(secondary_iter_decider, str):
            raise ValueError("secondary_iter_decider must be a string class name or a resolvable object path")

        primary_cls = get_iter_decider_class(primary_iter_decider)

        # The primary decider must exist before the secondary is resolved, because the
        # secondary (or its kwargs) may refer to it.
        self.primary_iter_decider = primary_cls(**primary_iter_decider_kwargs)

        # Resolve hierarchical references in secondary kwargs, e.g. "...primary" or "self.primary"
        def _get_attr_by_path(root_obj, path: str):
            """Follow a dotted attribute path from root_obj; empty segments are skipped."""
            cur = root_obj
            if not path:
                return cur
            for seg in path.split('.'):
                if not seg:
                    continue
                if not hasattr(cur, seg):
                    raise AttributeError(f"Cannot resolve reference segment '{seg}' in path '{path}'")
                cur = getattr(cur, seg)
            return cur

        def _normalize_primary_alias(path: str) -> str:
            """Identity hook: currently returns the path unchanged."""
            # Keep local names explicit; no remapping needed
            return path

        def _resolve_object_like_string(value):
            """Return the object a reference string points to, or None when value is not a reference."""
            if not isinstance(value, str):
                return None
            if value == "primary_iter_decider":
                return self.primary_iter_decider
            if value.startswith("self."):
                path = _normalize_primary_alias(value[len("self."):])
                return _get_attr_by_path(self, path)
            if value.startswith("..."):
                # Interpret as relative to this builder (InterleavedIterDecider)
                rel = value[len("..."):].lstrip('.')
                rel = _normalize_primary_alias(rel)
                return _get_attr_by_path(self, rel)
            return None

        def _cls_name(obj) -> str:
            """Return the class name of obj."""
            return obj.__class__.__name__

        # First, try to resolve the secondary_iter_decider itself as an object path
        _secondary_ref = _resolve_object_like_string(secondary_iter_decider)
        if isinstance(_secondary_ref, IterDecider):
            # Direct instance reuse for weight sharing
            self.secondary_iter_decider = _secondary_ref
        else:
            # Fall back to class-based construction with enhanced kwargs resolution/sharing
            secondary_cls = get_iter_decider_class(secondary_iter_decider)
            kwargs_copy = dict(secondary_iter_decider_kwargs or {})
            # Record any resolved object references so we can share the instance after construction
            _resolved_instance_map: Dict[str, IterDecider] = {}
            # Iterate over a snapshot because keys are added to / removed from kwargs_copy below
            for k, v in list(kwargs_copy.items()):
                ref_obj = _resolve_object_like_string(v)
                if ref_obj is not None and isinstance(ref_obj, IterDecider):
                    # Keep args as strings where possible for easier saving;
                    # convert object-like string into class name + kwargs captured from the instance.
                    if k.endswith("_cls"):
                        base_name = k[:-4]
                        _resolved_instance_map[base_name] = ref_obj
                        kwargs_copy[k] = _cls_name(ref_obj)
                        sibling_key = base_name + "_kwargs"
                        # Only fill the sibling kwargs when the caller did not provide any
                        if sibling_key not in kwargs_copy or not kwargs_copy[sibling_key]:
                            kwargs_copy[sibling_key] = getattr(ref_obj, "_init_args", {})
                    else:
                        # If key is not *_cls but points to a decider, try best-effort pair convention
                        # e.g., base_iter_decider -> base_iter_decider_cls/_kwargs
                        base_name = k
                        _resolved_instance_map[base_name] = ref_obj
                        kwargs_copy[k + "_cls"] = _cls_name(ref_obj)
                        kwargs_copy[k + "_kwargs"] = getattr(ref_obj, "_init_args", {})
                        # Remove original ambiguous key to avoid constructor conflicts
                        del kwargs_copy[k]
                else:
                    # leave as-is (string stays string for saveability)
                    pass
            self.secondary_iter_decider = secondary_cls(**kwargs_copy)
            # Training-time instance sharing: if a kwarg referenced an existing IterDecider instance,
            # replace the constructed attribute with the resolved instance to share weights.
            if isinstance(self.secondary_iter_decider, IterDecider) and _resolved_instance_map:
                for base_name, instance in _resolved_instance_map.items():
                    # Best-effort: failures to swap the attribute are silently ignored
                    try:
                        if hasattr(self.secondary_iter_decider, base_name):
                            setattr(self.secondary_iter_decider, base_name, instance)
                    except Exception:
                        pass

        # Interval configuration
        self.interval_steps = int(interval_steps)
        self.primary_interval_steps = int(primary_interval_steps) if primary_interval_steps is not None else None
        self.secondary_interval_steps = int(secondary_interval_steps) if secondary_interval_steps is not None else None
        # Unknown phase names are silently replaced by the defaults rather than rejected
        self.start_with = start_with if start_with in ("primary", "secondary") else "primary"
        self.eval_use = eval_use if eval_use in ("primary", "secondary") else "secondary"

        # Training state
        # Counter advanced by forward() only while no external step has been supplied
        self._internal_step_counter: int = 0
        # Step/epoch pushed in through update_training_state(); None until then
        self._external_global_step: Optional[int] = None
        self._external_epoch: Optional[int] = None

    # ---- Scheduling helpers (mirrors InterleavedLoss) ----
    def _get_phase(self, global_step: Optional[int] = None, training: bool = True) -> str:
        """Return "primary" or "secondary" for the given step.

        Outside training the configured ``eval_use`` is returned. Otherwise the step
        (``global_step``, or the internal counter when it is None) is mapped onto a
        cycle of ``primary + secondary`` interval lengths that opens with ``start_with``.
        Steps 0 and 1 both map to the first position of the cycle.
        """
        if not training:
            return self.eval_use

        step = int(global_step) if global_step is not None else self._internal_step_counter
        # `or` means both None and 0 fall back to the shared interval
        p_int = self.primary_interval_steps or self.interval_steps
        s_int = self.secondary_interval_steps or self.interval_steps
        cycle = p_int + s_int
        if cycle <= 0:
            return self.start_with
        # Steps are treated as 1-based; max() clamps step 0 onto the first slot
        pos = (max(step, 1) - 1) % cycle
        if self.start_with == "primary":
            return "primary" if pos < p_int else "secondary"
        else:
            return "secondary" if pos < s_int else "primary"

    def update_training_state(self, current_step: int, current_epoch: int):
        """Record the external step/epoch and forward them to nested deciders.

        Values that cannot be converted to int are stored as None. Nested deciders are
        updated on a best-effort basis: exceptions raised by their hooks are swallowed.
        """
        try:
            self._external_global_step = int(current_step)
        except Exception:
            self._external_global_step = None
        try:
            self._external_epoch = int(current_epoch)
        except Exception:
            self._external_epoch = None

        # Propagate training state to nested deciders if they expose the hook
        for d in (self.primary_iter_decider, self.secondary_iter_decider):
            if hasattr(d, 'update_training_state') and callable(d.update_training_state):
                try:
                    d.update_training_state(current_step=current_step, current_epoch=current_epoch)
                except Exception:
                    pass

    # ---- Forward ----
    def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> torch.Tensor:
        """Delegate the decision to the decider of the currently active phase.

        Args:
            logits: Token logits; the last dimension is dropped for the outputs of the
                capped branch.
            iter_depth: Iteration depth already processed.
            **kwargs: Forwarded unchanged to the active nested decider.

        Returns:
            When ``iter_depth >= max_iter``: an all-False bool decision and logits filled
            with MINUS_INFINITY_LOGITS, as a 2-tuple. Otherwise whatever the active nested
            decider returns (the annotation says Tensor, but the capped branch returns a
            tuple).
        """
        if iter_depth >= self.max_iter:
            decision = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device)
            logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device)
            return decision, logits_out

        # Heuristic increment for internal counter when no external state is provided.
        # Note this runs on every non-capped forward call, in eval mode as well.
        if self._external_global_step is None:
            self._internal_step_counter += 1

        active_phase = self._get_phase(global_step=self._external_global_step, training=self.training if hasattr(self, 'training') else True)
        decider = self.primary_iter_decider if active_phase == "primary" else self.secondary_iter_decider
        # Update exposed threshold to match the active nested decider for finishing rule.
        # The nested threshold is read before the nested forward call below.
        nested_thr = getattr(decider, 'threshold', 0.5)
        try:
            if isinstance(nested_thr, torch.Tensor):
                nested_thr = float(nested_thr.detach().item())
        except Exception:
            pass
        # Update buffer in-place to avoid re-binding the attribute
        try:
            with torch.no_grad():
                self.threshold.fill_(float(nested_thr))
        except Exception:
            # Last-resort: copy from a freshly created tensor on the same device/dtype
            self.threshold.copy_(torch.tensor(float(nested_thr), dtype=self.threshold.dtype, device=self.threshold.device))
        return decider(logits, iter_depth, **kwargs)


@register_iter_decider
@capture_init_args
class AlwaysWrapperIterDecider(IterDecider):
    """Wrapper that enforces a simple control-flow policy around a base iter decider.

    The base decider is still evaluated so that its logits are returned, but the
    boolean decision is replaced by a constant that depends on ``mode``.

    Modes:
    - "continue": force continuing until the final allowed iteration (previous behavior)
    - "stop": stop after the first iteration

    Finishing rule via threshold (used by HR2R's finished_mask = (prob <= threshold)):
    - continue: threshold = -1.0 until last iteration, then 1.0 to finish all
    - stop: threshold = 1.0 at the first iteration so all tokens finish immediately

    Args:
        max_iter: Iteration cap; also the default ``max_iter`` of the base decider.
        base_iter_decider_cls: Registry class name of the wrapped decider (string only).
        base_iter_decider_kwargs: Constructor kwargs for the wrapped decider.
        mode: "continue" or "stop" (case-insensitive, surrounding whitespace ignored).

    Raises:
        ValueError: If ``base_iter_decider_cls`` is not a string or ``mode`` is unknown.
    """

    def __init__(
        self,
        max_iter: int = 3,
        base_iter_decider_cls: str = "EntropyIterDecider",
        base_iter_decider_kwargs: Optional[dict] = None,
        mode: str = "continue",
    ):
        super().__init__(max_iter=max_iter)
        if not isinstance(base_iter_decider_cls, str):
            raise ValueError("AlwaysWrapperIterDecider expects base_iter_decider_cls as a string class name")
        mode = str(mode).lower().strip()
        if mode not in ("continue", "stop"):
            raise ValueError("AlwaysWrapperIterDecider mode must be either 'continue' or 'stop'")
        self.mode = mode

        base_cls = get_iter_decider_class(base_iter_decider_cls)
        # Copy so setdefault does not leak `max_iter` into the caller's dict
        base_iter_decider_kwargs = dict(base_iter_decider_kwargs or {})
        base_iter_decider_kwargs.setdefault("max_iter", max_iter)
        self.base_iter_decider = base_cls(**base_iter_decider_kwargs)
        # NOTE(review): keep this assignment AFTER super().__init__(). The base class
        # registers a `threshold` buffer while this class exposes `threshold` as a
        # property; the registration appears to succeed only because the property
        # raises AttributeError while this attribute is still missing - confirm
        # against nn.Module.register_buffer before reordering.
        self._last_forward_iter_depth: Optional[int] = None

    def update_training_state(self, current_step: int, current_epoch: int):
        """Forward the training state to the base decider when it exposes the hook.

        Best-effort: exceptions raised by the base decider's hook are swallowed.
        """
        base = getattr(self, 'base_iter_decider', None)
        if base is None:
            return
        hook = getattr(base, 'update_training_state', None)
        if not callable(hook):
            return
        try:
            hook(current_step=current_step, current_epoch=current_epoch)
        except Exception:
            pass

    def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the mode's constant decision together with the base decider's logits.

        Args:
            logits: Token logits, shape (..., vocab_size).
            iter_depth: Iteration depth already processed.
            **kwargs: Forwarded unchanged to the base decider.

        Returns:
            ``(decision, continue_logits)``. At or beyond ``max_iter`` the decision is
            all-False and the logits are MINUS_INFINITY_LOGITS. Otherwise the decision is
            all-True ("continue") or all-False ("stop"), and the logits are the base
            decider's logits cast to ``logits.dtype`` (NEUTRAL_LOGITS if it returned None).

        Raises:
            TypeError: If the base decider does not return a tuple of at least two items.
        """
        # Record the depth BEFORE the cap check. Previously it was only recorded on the
        # non-capped path, so `threshold` could never observe depth >= max_iter and its
        # "finish all" branch (1.0) was unreachable in continue mode.
        self._last_forward_iter_depth = int(iter_depth)

        # Respect cap for shape consistency
        if iter_depth >= self.max_iter:
            decision = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device)
            logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device)
            return decision, logits_out

        # Delegate to base decider to obtain its logits if any
        base_output = self.base_iter_decider(logits, iter_depth, **kwargs)
        if (not isinstance(base_output, tuple)) or (len(base_output) < 2):
            raise TypeError("Base iter decider must return a (decision_bool, logits) tuple")
        base_decision, base_logits = base_output[0], base_output[1]

        if self.mode == "continue":
            decision = torch.ones_like(base_decision, dtype=torch.bool, device=base_decision.device)
        else:  # stop after first iteration
            decision = torch.zeros_like(base_decision, dtype=torch.bool, device=base_decision.device)

        if base_logits is None:
            logits_out = torch.full_like(decision, NEUTRAL_LOGITS, dtype=logits.dtype)
        else:
            logits_out = base_logits.to(dtype=logits.dtype)
        return decision, logits_out

    @property
    def threshold(self) -> float:
        """Threshold implementing the finishing rule for the most recent forward call.

        Before any forward call, the base decider's threshold is returned (0.5 when it
        has none). Afterwards: "stop" always yields 1.0; "continue" yields -1.0 while the
        last seen depth is below ``max_iter`` and 1.0 once the cap has been reached.
        """
        # If last forward depth is unknown, fall back to base threshold or 0.5
        if self._last_forward_iter_depth is None:
            thr = getattr(self.base_iter_decider, 'threshold', None)
            if thr is None:
                return 0.5
            if isinstance(thr, torch.Tensor):
                return float(thr.detach().item())
            return float(thr)

        if self.mode == "continue":
            # Continue until final iteration, then let every token finish
            return -1.0 if self._last_forward_iter_depth < self.max_iter else 1.0

        # Stop mode: finish immediately at the first iteration
        return 1.0

@register_iter_decider
@capture_init_args
class AlwaysContinueWrapperIterDecider(AlwaysWrapperIterDecider):
    """Backward-compatible alias for AlwaysWrapperIterDecider in 'continue' mode.

    Accepts the same arguments as the parent except ``mode``, which is fixed to
    "continue".
    """

    def __init__(
        self,
        max_iter: int = 3,
        base_iter_decider_cls: str = "EntropyIterDecider",
        base_iter_decider_kwargs: Optional[dict] = None,
    ):
        # Forward everything by keyword, pinning the mode.
        parent_kwargs = {
            "max_iter": max_iter,
            "base_iter_decider_cls": base_iter_decider_cls,
            "base_iter_decider_kwargs": base_iter_decider_kwargs,
            "mode": "continue",
        }
        super().__init__(**parent_kwargs)

@register_iter_decider
@capture_init_args
class NoisyWrapperIterDecider(IterDecider):
    """Wrapper iter_decider that injects training-scheduled noise into the base decider's decision.

    This class composes a base iter_decider and a NoiseScheduler. It decouples noise scheduling
    from the Trainer by owning the scheduler internally and exposing the same scheduling hooks
    as other deciders (num_grow_steps/num_epochs + update_training_state).

    Behavior:
    - Calls the base iter_decider to obtain a decision tensor
    - Applies token-level noise via apply_noise using scheduler-controlled probabilities
    - Returns the noised decision alongside the base auxiliary output

    Args:
        max_iter: Iteration cap; also the default ``max_iter`` of the base decider.
        base_iter_decider_cls: Registry class name or class of the wrapped decider.
        base_iter_decider_kwargs: Constructor kwargs for the wrapped decider.
        noise_scheduler_kwargs: Optional overrides for the NoiseScheduler arguments
            ``scheduler_type``, ``max_important_noise``, ``max_normal_noise``,
            ``num_epochs``, ``num_grow_steps`` and ``exp_gamma``.
        use_token_noise: When False the base decision is returned without noise.
    """

    def __init__(
        self,
        max_iter: int = 3,
        base_iter_decider_cls: Union[str, Type[IterDecider]] = "DynamicMismatchIterDecider",
        base_iter_decider_kwargs: Optional[dict] = None,
        noise_scheduler_kwargs: Optional[dict] = None,
        use_token_noise: bool = True,
    ):
        super().__init__(max_iter=max_iter)
        # Instantiate base iter decider
        if isinstance(base_iter_decider_cls, str):
            base_iter_decider_cls = get_iter_decider_class(base_iter_decider_cls)
        # Copy so setdefault does not leak `max_iter` into the caller's dict
        base_iter_decider_kwargs = dict(base_iter_decider_kwargs or {})
        base_iter_decider_kwargs.setdefault("max_iter", max_iter)
        self.base_iter_decider = base_iter_decider_cls(**base_iter_decider_kwargs)

        # Instantiate internal noise scheduler
        scheduler_cfg = noise_scheduler_kwargs or {}
        self.noise_scheduler = NoiseScheduler(
            scheduler_type=scheduler_cfg.get("scheduler_type", "constant"),
            max_important_noise=scheduler_cfg.get("max_important_noise", 0.1),
            max_normal_noise=scheduler_cfg.get("max_normal_noise", 0.025),
            num_epochs=scheduler_cfg.get("num_epochs", None),
            num_grow_steps=scheduler_cfg.get("num_grow_steps", 0),
            exp_gamma=scheduler_cfg.get("exp_gamma", 2.0),
        )

        # Expose scheduling meta that LoggerCallback can align; keep in sync with scheduler
        self._num_grow_steps = self.noise_scheduler.num_grow_steps
        self._num_epochs = self.noise_scheduler.num_epochs

        # Control whether to apply noise
        self.use_token_noise = use_token_noise

    def _propagate_to_base(self, attr_name: str, value) -> None:
        """Best-effort copy of a scheduling attribute onto the base decider if it has one."""
        if hasattr(self.base_iter_decider, attr_name):
            try:
                setattr(self.base_iter_decider, attr_name, value)
            except Exception:
                # Deliberately best-effort: a read-only attribute on the base is not fatal
                pass

    @property
    def num_grow_steps(self) -> int:
        """Number of grow steps shared with the internal noise scheduler."""
        return self._num_grow_steps

    @num_grow_steps.setter
    def num_grow_steps(self, value: Optional[int]):
        self._num_grow_steps = int(value) if value is not None else 0
        self.noise_scheduler.num_grow_steps = self._num_grow_steps
        # Best-effort propagate to base if it exposes the same meta
        self._propagate_to_base("num_grow_steps", self._num_grow_steps)

    @property
    def num_epochs(self) -> Optional[int]:
        """Number of epochs shared with the internal noise scheduler."""
        return self._num_epochs

    @num_epochs.setter
    def num_epochs(self, value: Optional[int]):
        # NOTE(review): None is coerced to 0 here, whereas the constructor passes None
        # through to the scheduler unchanged - confirm which one is intended.
        self._num_epochs = int(value) if value is not None else 0
        self.noise_scheduler.num_epochs = self._num_epochs
        self._propagate_to_base("num_epochs", self._num_epochs)

    def update_training_state(self, current_step: int, current_epoch: int):
        """Update internal scheduler and forward to base decider if supported."""
        self.noise_scheduler.update_training_state(current_step=current_step, current_epoch=current_epoch)
        hook = getattr(self.base_iter_decider, "update_training_state", None)
        if callable(hook):
            try:
                hook(current_step=current_step, current_epoch=current_epoch)
            except Exception:
                # Best-effort: the base decider's hook must not break the training loop
                pass

    def get_current_noise_values(self) -> Tuple[float, float]:
        """Return current (important_noise, normal_noise)."""
        return self.noise_scheduler.get_current_noise_values()

    def forward(self, logits: torch.Tensor, iter_depth: int, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (optionally noised) base decision and the base decider's logits.

        Args:
            logits: Token logits, shape (..., vocab_size).
            iter_depth: Iteration depth already processed.
            **kwargs: Forwarded to the base decider with ``use_token_noise=False`` forced.

        Returns:
            ``(decision, continue_logits)``. At or beyond ``max_iter`` the decision is
            all-False and the logits are MINUS_INFINITY_LOGITS. Otherwise the decision is
            the base decision (passed through ``apply_noise`` when ``use_token_noise`` is
            set) as a bool tensor, and the logits are the base logits cast to
            ``logits.dtype`` (NEUTRAL_LOGITS if the base returned None).

        Raises:
            TypeError: If the base decider does not return a tuple of at least two items.
        """
        # Respect maximum iteration cap
        if iter_depth >= self.max_iter:
            decision = torch.zeros(logits.shape[:-1], dtype=torch.bool, device=logits.device)
            logits_out = torch.full(decision.shape, MINUS_INFINITY_LOGITS, dtype=logits.dtype, device=logits.device)
            return decision, logits_out

        # Avoid double noise by disabling it for base if supported via kwargs
        base_kwargs = dict(kwargs)
        base_kwargs["use_token_noise"] = False
        base_output = self.base_iter_decider(logits, iter_depth, **base_kwargs)
        if (not isinstance(base_output, tuple)) or (len(base_output) < 2):
            raise TypeError("Base iter decider must return a (decision_bool, logits) tuple")
        base_decision, base_logits = base_output[0], base_output[1]

        decision = base_decision
        if self.use_token_noise:
            important_noise, normal_noise = self.noise_scheduler.get_current_noise_values()
            decision = apply_noise(
                decision_mask=base_decision,
                important_token_noise=important_noise,
                normal_token_noise=normal_noise,
            )
        decision = decision.bool()

        if base_logits is None:
            logits_out = torch.full_like(decision, NEUTRAL_LOGITS, dtype=logits.dtype)
        else:
            logits_out = base_logits.to(dtype=logits.dtype)
        return decision, logits_out

def save_iter_decider(iter_decider: IterDecider, save_directory: str):
    """Save iter_decider state dict and configuration.

    Writes ``<save_directory>/iter_decider.bin`` containing the class name, the state
    dict (moved to CPU) and the constructor arguments captured in ``_init_args``
    (an empty dict when the attribute is absent).

    Args:
        iter_decider: The decider to serialize.
        save_directory: Target directory; created if it does not exist yet.
    """
    # Use captured initialization arguments from the decorator
    init_args = getattr(iter_decider, "_init_args", {})

    # Use natural state_dict - no overrides needed; tensors are moved to CPU for portability
    state_dict = {k: v.cpu() for k, v in iter_decider.state_dict().items()}
    data = {
        "class": iter_decider.__class__.__name__,
        "state_dict": state_dict,
        "init_args": init_args,
    }

    # Saving into a fresh output directory must not fail; an empty string means the
    # current working directory, which needs no creation.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    save_path = os.path.join(save_directory, "iter_decider.bin")
    print(f"Saving iter_decider with {len(state_dict)} parameters to {save_path}")
    torch.save(data, save_path)


def load_iter_decider(load_directory: str, class_name: Optional[str] = None, init_args: Optional[dict] = None) -> IterDecider:
    """Rebuild an iter_decider from ``<load_directory>/iter_decider.bin``.

    Args:
        load_directory: Directory containing ``iter_decider.bin``.
        class_name: Registry class name; taken from the file's "class" entry when None.
        init_args: Constructor kwargs; taken from the file's "init_args" entry when None.

    Returns:
        The constructed decider with the saved state dict loaded non-strictly. State
        dict entries whose key also appears in the constructor kwargs are skipped.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If no class name is given and none is stored in the file.
    """
    path = os.path.join(load_directory, "iter_decider.bin")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No iter_decider found at {path}")

    # NOTE: weights_only=False lets torch.load unpickle arbitrary objects; only point
    # this at checkpoints from a trusted source.
    data = torch.load(path, map_location="cpu", weights_only=False)

    resolved_class_name = data.get("class") if class_name is None else class_name
    if not resolved_class_name:
        raise ValueError("No iter_decider class specified in saved data")

    # Explicit constructor arguments take precedence over the saved ones
    ctor_kwargs = data.get("init_args", {}) if init_args is None else init_args

    # Instantiate through the registry
    iter_decider = get_iter_decider_class(resolved_class_name)(**ctor_kwargs)

    state_dict = data.get("state_dict", {})
    if not state_dict:
        return iter_decider

    # Drop entries that share a name with a constructor argument
    filtered_state_dict = {}
    for key, value in state_dict.items():
        if key in ctor_kwargs:
            print(f"Skipping state_dict key '{key}' as it conflicts with init_args")
            continue
        filtered_state_dict[key] = value

    print(f"Loading iter_decider state dict with {len(filtered_state_dict)} parameters (filtered from {len(state_dict)})")
    if filtered_state_dict:
        iter_decider.load_state_dict(filtered_state_dict, strict=False)

    return iter_decider
