import os
import copy
from abc import ABC, abstractmethod
from typing import Optional, TYPE_CHECKING

import torch
import torch.nn as nn

if TYPE_CHECKING:
    from hr2r.model.recurrent_transformer import HR2RForCausalLM

from hr2r.model.registry import register_input_updater, get_input_updater_class, capture_init_args


class InputUpdater(nn.Module, ABC):
    """
    Abstract interface for modules that compute the next iteration's input embeddings.

    Concrete updaters receive the outputs of the previous iteration (logits and,
    optionally, hidden states) together with the previous inputs and the embedding
    matrix, and return the inputs to feed into the next iteration.

    Tensors are documented with arbitrary leading dimensions (...), e.g. batch
    and/or sequence; only the last dimension carries vocabulary or embedding
    features, and implementations are expected to keep the leading dimensions intact.
    """

    @abstractmethod
    def forward(
        self,
        logits: torch.Tensor,
        prev_inputs: torch.Tensor,
        embedding_weight: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the inputs for the next iteration.

        Args:
            logits: Logits of the previous iteration, shape (..., vocab_size).
            prev_inputs: Inputs of the previous iteration, shape (..., embed_dim).
            embedding_weight: Embedding matrix, shape (vocab_size, embed_dim).
            hidden_states: Optional hidden states, shape (..., hidden_dim).

        Returns:
            The updated inputs, shape (..., embed_dim).

        Note:
            The leading dimensions (...) of the result are expected to match those
            of the arguments exactly.
        """

@register_input_updater
@capture_init_args
class AdditiveUpdater(InputUpdater):
    """
    Additive update that blends the previous inputs with a soft token embedding.

    The soft embedding ``new_embed`` is the softmax(logits)-weighted sum of the rows
    of ``embedding_weight``, optionally restricted to the top-k logits. The result is
    ``w * prev_inputs + (1 - w) * new_embed`` where:

    - ``learnable_weight=True``: ``w = sigmoid(weight / scalar_temperature)`` with
      ``weight`` a trainable parameter initialized to zero (so ``w`` starts at 0.5).
    - ``learnable_weight=False``: ``w`` is the ``weight`` buffer used as-is. It is
      initialized to zero, in which case the output equals ``new_embed``.
    """

    def __init__(self, topk: Optional[int] = None, learnable_weight: bool = True, scalar_temperature: float = 1.0, dtype: torch.dtype = torch.bfloat16):
        """
        Args:
            topk: If set, only the top-k logits take part in the softmax mix.
            learnable_weight: Whether the blend weight is a trainable parameter.
            scalar_temperature: Divisor applied to the raw weight before the sigmoid
                (only used when ``learnable_weight`` is True).
            dtype: Dtype of the blend weight tensor.
        """
        super().__init__()
        self.topk = topk
        self.learnable_weight = learnable_weight
        self.scalar_temperature = scalar_temperature

        # The attribute is called `weight` in both modes so the state_dict key is the
        # same whether it is a trainable parameter or a fixed buffer.
        if self.learnable_weight:
            self.weight = nn.Parameter(torch.zeros(1, dtype=dtype))
        else:
            self.register_buffer('weight', torch.zeros(1, dtype=dtype))

    def forward(
        self,
        logits: torch.Tensor,
        prev_inputs: torch.Tensor,
        embedding_weight: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return ``w * prev_inputs + (1 - w) * new_embed``.

        Args:
            logits: Logits, shape (..., vocab_size).
            prev_inputs: Previous inputs, shape (..., embed_dim).
            embedding_weight: Embedding matrix, shape (vocab_size, embed_dim).
            hidden_states: Unused by this updater.

        Returns:
            The blended inputs, shape (..., embed_dim).
        """
        # The soft embedding is always cast back to the dtype of the logits.
        out_dtype = logits.dtype

        if self.topk is not None:
            # Softmax over the k largest logits only (k is capped at vocab_size).
            k = min(self.topk, logits.size(-1))
            topk_values, topk_indices = torch.topk(logits, k=k, dim=-1)
            topk_probs = torch.softmax(topk_values, dim=-1)  # (..., k)

            # Gather the matching embedding rows: (..., k, embed_dim)
            topk_embeddings = embedding_weight[topk_indices]

            # Weighted sum over the k candidates -> (..., embed_dim).
            # The elementwise product type-promotes if the dtypes differ.
            new_embed = torch.sum(topk_probs.unsqueeze(-1) * topk_embeddings, dim=-2)
        else:
            probs = torch.softmax(logits, dim=-1)
            # Unlike the broadcasting product above, matmul does not type-promote and
            # raises when logits and embeddings have different dtypes (e.g. float32
            # logits with bfloat16 embeddings). Casting is a no-op when they match.
            new_embed = probs.to(embedding_weight.dtype) @ embedding_weight

        new_embed = new_embed.to(dtype=out_dtype)

        if self.learnable_weight:
            scalar_weight = torch.sigmoid(self.weight / self.scalar_temperature)
        else:
            scalar_weight = self.weight

        return scalar_weight * prev_inputs + (1 - scalar_weight) * new_embed

@register_input_updater
@capture_init_args
class MLPUpdater(InputUpdater):
    """
    MLP updater: concatenates ``prev_inputs`` with a soft token embedding and maps
    the result back to ``embed_dim`` with an MLP.

    The soft embedding is the softmax(logits)-weighted sum of rows of
    ``embedding_weight`` (optionally restricted to the top-k logits). The MLP acts
    on the last dimension only, so leading dimensions (...) pass through unchanged.
    """

    def __init__(self, embed_dim: int, topk: Optional[int] = None, hidden_dim: Optional[int] = None, num_layers: int = 1, dtype: torch.dtype = torch.bfloat16):
        """
        Args:
            embed_dim: Size of the input embeddings (and of the output).
            topk: If set, only the top-k logits take part in the softmax mix.
            hidden_dim: Width of the hidden layers; defaults to ``embed_dim``.
            num_layers: Total number of Linear layers; ``num_layers - 1`` of them are
                hidden layers followed by GELU.
            dtype: Dtype of the Linear layers.
        """
        super().__init__()
        self.topk = topk
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.embed_dim = embed_dim

        self._init_network(dtype)

    def _init_network(self, dtype: torch.dtype) -> None:
        """Build ``self.update_network`` mapping (..., 2*embed_dim) -> (..., embed_dim)."""
        # The network consumes [prev_inputs, new_embed] concatenated on the last dim.
        input_dim = self.embed_dim * 2

        # Resolve the default hidden width (stored back on the instance).
        if self.hidden_dim is None:
            self.hidden_dim = self.embed_dim

        layers = []
        current_dim = input_dim

        # Hidden layers; range() is empty when num_layers <= 1, leaving a single Linear.
        for _ in range(self.num_layers - 1):
            layers.append(nn.Linear(current_dim, self.hidden_dim, dtype=dtype))
            layers.append(nn.GELU())
            current_dim = self.hidden_dim

        # Output projection back to embed_dim (no activation).
        layers.append(nn.Linear(current_dim, self.embed_dim, dtype=dtype))

        self.update_network = nn.Sequential(*layers)

    def forward(
        self,
        logits: torch.Tensor,
        prev_inputs: torch.Tensor,
        embedding_weight: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return ``update_network(concat(prev_inputs, new_embed))``.

        Args:
            logits: Logits, shape (..., vocab_size).
            prev_inputs: Previous inputs, shape (..., embed_dim).
            embedding_weight: Embedding matrix, shape (vocab_size, embed_dim).
            hidden_states: Unused by this updater.

        Returns:
            The updated inputs, shape (..., embed_dim).
        """
        if self.topk is not None:
            # Softmax over the k largest logits only (k is capped at vocab_size).
            k = min(self.topk, logits.size(-1))
            topk_values, topk_indices = torch.topk(logits, k=k, dim=-1)
            topk_probs = torch.softmax(topk_values, dim=-1)  # (..., k)

            # Gather the matching embedding rows: (..., k, embed_dim)
            topk_embeddings = embedding_weight[topk_indices]

            # Weighted sum over the k candidates -> (..., embed_dim)
            new_embed = torch.sum(topk_probs.unsqueeze(-1) * topk_embeddings, dim=-2)
        else:
            probs = torch.softmax(logits, dim=-1)
            # matmul does not type-promote and raises when logits and embeddings have
            # different dtypes; casting is a no-op when they already match.
            new_embed = probs.to(embedding_weight.dtype) @ embedding_weight

        # (..., embed_dim) ++ (..., embed_dim) -> (..., 2*embed_dim)
        concat_inputs = torch.cat([prev_inputs, new_embed], dim=-1)

        # (..., 2*embed_dim) -> (..., embed_dim)
        return self.update_network(concat_inputs)

@register_input_updater
@capture_init_args
class TrivialUpdater(InputUpdater):
    """
    Parameter-free updater.

    Depending on ``use_hidden_states`` it returns either a slice of the hidden
    states or the softmax(logits)-weighted sum of rows of ``embedding_weight``
    (optionally restricted to the top-k logits). ``prev_inputs`` is ignored.
    """

    def __init__(self, use_hidden_states: bool = False, topk: Optional[int] = None):
        """
        Args:
            use_hidden_states: Return ``hidden_states[..., -1, :]`` instead of a
                logits-weighted embedding.
            topk: If set, only the top-k logits take part in the softmax mix.
        """
        super().__init__()
        self.use_hidden_states = use_hidden_states
        self.topk = topk

    def forward(
        self,
        logits: torch.Tensor,
        prev_inputs: torch.Tensor,
        embedding_weight: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the next inputs without any learned transformation.

        Args:
            logits: Logits, shape (..., vocab_size).
            prev_inputs: Unused by this updater.
            embedding_weight: Embedding matrix, shape (vocab_size, embed_dim).
            hidden_states: Required when ``use_hidden_states`` is True.

        Returns:
            The new inputs, shape (..., embed_dim).

        Raises:
            ValueError: If ``use_hidden_states`` is True and ``hidden_states`` is None.
        """
        if self.use_hidden_states:
            if hidden_states is None:
                # Fail with a clear message instead of "'NoneType' is not subscriptable".
                raise ValueError("hidden_states must be provided when use_hidden_states=True")
            # Takes the last entry along the second-to-last axis.
            # NOTE(review): presumably that axis indexes layers, i.e. hidden_states is
            # (..., num_layers, embed_dim) and this picks the last layer - confirm
            # against the caller.
            return hidden_states[..., -1, :]

        if self.topk is not None:
            # Softmax over the k largest logits only (k is capped at vocab_size).
            k = min(self.topk, logits.size(-1))
            topk_values, topk_indices = torch.topk(logits, k=k, dim=-1)
            topk_probs = torch.softmax(topk_values, dim=-1)  # (..., k)
            topk_embeddings = embedding_weight[topk_indices]  # (..., k, embed_dim)
            return torch.sum(topk_probs.unsqueeze(-1) * topk_embeddings, dim=-2)

        probs = torch.softmax(logits, dim=-1)
        # (..., vocab_size) @ (vocab_size, embed_dim) -> (..., embed_dim).
        # matmul does not type-promote and raises when logits and embeddings have
        # different dtypes; casting is a no-op when they already match.
        return probs.to(embedding_weight.dtype) @ embedding_weight


class TransformerUpdaterBlock(nn.Module):
    """
    One pre-norm feed-forward block of the updater backbone.

    Computes ``dim_change(x) + mlp(layer_norm(x))``: the input is layer-normalized,
    passed through an expand-then-project MLP with dropout, and added to a skip
    path that is the identity when input and output widths match and a Linear
    projection otherwise.
    """
    def __init__(
        self,
        input_dim,
        output_dim,
        expansion_factor=4,
        dropout_rate=0.1,
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Width of the MLP's inner layer.
        expanded_dim = input_dim * expansion_factor

        # Normalization applied to the MLP branch only.
        self.layer_norm = nn.LayerNorm(input_dim, dtype=dtype)

        # Expand -> GELU -> dropout -> project -> dropout.
        feed_forward = [
            nn.Linear(input_dim, expanded_dim, dtype=dtype),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(expanded_dim, output_dim, dtype=dtype),
            nn.Dropout(dropout_rate),
        ]
        self.mlp = nn.Sequential(*feed_forward)

        # Skip path: project only when the width actually changes.
        if input_dim == output_dim:
            self.dim_change = nn.Identity()
        else:
            self.dim_change = nn.Linear(input_dim, output_dim, dtype=dtype)

    def forward(self, x):
        """Return the skip path plus the normalized MLP branch for input ``x``."""
        skip = self.dim_change(x)
        branch = self.mlp(self.layer_norm(x))
        return skip + branch


class TransformerUpdaterBackbone(nn.Module):
    """
    Backbone for the transformer-style updater.

    Pipeline: Linear(input_dim -> hidden_dims[0]) -> one TransformerUpdaterBlock per
    entry of ``hidden_dims`` -> LayerNorm -> Linear(hidden_dims[-1] -> output_dim).
    Block ``i`` maps ``hidden_dims[i]`` to ``hidden_dims[i + 1]``; the last block
    keeps its width.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dims=(1024, 2048, 1024),  # Bottleneck structure
        expansion_factor=4,  # Transformer-style expansion
        dropout_rate=0.1,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Size of the last dimension of the output.
            hidden_dims: Non-empty sequence (list or tuple) of block widths.
            expansion_factor: Inner expansion of each block's MLP.
            dropout_rate: Dropout probability used inside each block.
            dtype: Dtype of all created layers.

        Raises:
            ValueError: If ``hidden_dims`` is empty.
        """
        super().__init__()

        # Copy into a list: accepts tuples and never aliases the caller's sequence
        # or a shared default.
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one dimension")

        # NOTE: `blocks` is registered before `input_projection` on purpose; the
        # registration order determines the order of self.modules() (and therefore
        # the order in which _init_weights consumes random numbers) and of the
        # state_dict keys.
        self.blocks = nn.ModuleList()

        # First projection to match the first hidden dimension
        self.input_projection = nn.Linear(input_dim, hidden_dims[0], dtype=dtype)

        # Repeat the last width so the final block maps hidden_dims[-1] to itself.
        block_dims = hidden_dims + [hidden_dims[-1]]

        # Consecutive pairs give each block's (input, output) widths.
        for block_input_dim, block_output_dim in zip(block_dims, block_dims[1:]):
            self.blocks.append(
                TransformerUpdaterBlock(
                    input_dim=block_input_dim,
                    output_dim=block_output_dim,
                    expansion_factor=expansion_factor,
                    dropout_rate=dropout_rate,
                    dtype=dtype,
                )
            )

        # Output layer
        self.output_layer = nn.Sequential(
            nn.LayerNorm(hidden_dims[-1], dtype=dtype),
            nn.Linear(hidden_dims[-1], output_dim, dtype=dtype),
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Re-initialize every Linear layer: Kaiming-normal weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """
        Forward pass through the backbone.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same leading dimensions and last dimension ``output_dim``.
        """
        # Project input to first hidden dimension
        x = self.input_projection(x)

        # Process through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Final LayerNorm + Linear
        return self.output_layer(x)


@register_input_updater
@capture_init_args
class TransformerUpdater(InputUpdater):
    """
    Transformer-style updater.

    Uses the same preprocessing as MLPUpdater (concatenation of ``prev_inputs`` and a
    softmax(logits)-weighted embedding) but feeds the result through a
    TransformerUpdaterBackbone instead of a plain MLP.
    """

    def __init__(
        self, 
        embed_dim: int, 
        topk: Optional[int] = None, 
        hidden_dims: Optional[list] = None,
        expansion_factor: int = 4,
        dropout_rate: float = 0.1,
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            embed_dim: Size of the input embeddings (and of the output).
            topk: If set, only the top-k logits take part in the softmax mix.
            hidden_dims: Backbone block widths; defaults to
                ``[embed_dim, embed_dim * 2, embed_dim]``.
            expansion_factor: Inner expansion of each backbone block's MLP.
            dropout_rate: Dropout probability used inside the backbone blocks.
            dtype: Dtype of the backbone layers.
        """
        super().__init__()
        self.topk = topk
        self.embed_dim = embed_dim
        self.expansion_factor = expansion_factor
        self.dropout_rate = dropout_rate

        # Set default hidden dimensions if not provided
        if hidden_dims is None:
            self.hidden_dims = [embed_dim, embed_dim * 2, embed_dim]
        else:
            self.hidden_dims = hidden_dims

        # Initialize the transformer backbone
        self._init_backbone(dtype)

    def _init_backbone(self, dtype: torch.dtype) -> None:
        """Build ``self.backbone`` mapping (..., 2*embed_dim) -> (..., embed_dim)."""
        self.backbone = TransformerUpdaterBackbone(
            # The backbone consumes [prev_inputs, new_embed] concatenated on the last dim.
            input_dim=self.embed_dim * 2,
            output_dim=self.embed_dim,
            hidden_dims=self.hidden_dims,
            expansion_factor=self.expansion_factor,
            dropout_rate=self.dropout_rate,
            dtype=dtype,
        )

    def forward(
        self,
        logits: torch.Tensor,
        prev_inputs: torch.Tensor,
        embedding_weight: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return ``backbone(concat(prev_inputs, new_embed))``.

        Args:
            logits: Logits, shape (..., vocab_size).
            prev_inputs: Previous inputs, shape (..., embed_dim).
            embedding_weight: Embedding matrix, shape (vocab_size, embed_dim).
            hidden_states: Unused by this updater.

        Returns:
            The updated inputs, shape (..., embed_dim).
        """
        if self.topk is not None:
            # Softmax over the k largest logits only (k is capped at vocab_size).
            k = min(self.topk, logits.size(-1))
            topk_values, topk_indices = torch.topk(logits, k=k, dim=-1)
            topk_probs = torch.softmax(topk_values, dim=-1)  # (..., k)

            # Gather the matching embedding rows: (..., k, embed_dim)
            topk_embeddings = embedding_weight[topk_indices]

            # Weighted sum over the k candidates -> (..., embed_dim)
            new_embed = torch.sum(topk_probs.unsqueeze(-1) * topk_embeddings, dim=-2)
        else:
            probs = torch.softmax(logits, dim=-1)
            # matmul does not type-promote and raises when logits and embeddings have
            # different dtypes; casting is a no-op when they already match.
            new_embed = probs.to(embedding_weight.dtype) @ embedding_weight

        # (..., embed_dim) ++ (..., embed_dim) -> (..., 2*embed_dim)
        concat_inputs = torch.cat([prev_inputs, new_embed], dim=-1)

        # (..., 2*embed_dim) -> (..., embed_dim)
        return self.backbone(concat_inputs)

@register_input_updater
@capture_init_args
class NeuralAdditiveUpdater(InputUpdater):
    """
    Neural additive updater: blends ``prev_inputs`` with the output of an MLP.

    The MLP input is either ``hidden_states`` (``use_hidden_states=True``) or the
    softmax(logits)-weighted sum of rows of ``embedding_weight`` (optionally
    restricted to the top-k logits). The update is

        w * prev_inputs + (1 - w) * additive_network(network_input)

    with ``w = sigmoid(additive_weight / scalar_temperature)``. With the default
    ``init_weight=0.0`` the blend starts at ``w = 0.5``. When
    ``zero_initialize=True`` all Linear layers of the MLP start at zero, so the
    MLP initially outputs zeros and the update reduces to ``w * prev_inputs``.
    The MLP acts on the last dimension only; leading dimensions (...) are kept.
    """

    def __init__(
        self,
        embed_dim: int,
        topk: Optional[int] = None,
        hidden_dim: int = 1024,
        num_layers: int = 2,
        dropout: float = 0.1,
        init_weight: float = 0.0,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        use_hidden_states: bool = False,
        zero_initialize: bool = False,
        dtype: torch.dtype = torch.bfloat16,
        scalar_temperature: float = 1.0,
    ):
        """
        Args:
            embed_dim: Input and output width of the MLP.
            topk: If set, only the top-k logits take part in the softmax mix.
            hidden_dim: Width of the MLP's hidden layers.
            num_layers: Number of Linear layers in the MLP (1 gives a single Linear).
            dropout: Dropout probability after each hidden activation.
            init_weight: Initial value of the raw (pre-sigmoid) blend weight.
            activation: One of "gelu", "relu", "silu" (case-insensitive).
            use_layer_norm: Insert LayerNorm after each hidden Linear.
            use_hidden_states: Feed ``hidden_states`` to the MLP instead of the
                logits-weighted embedding.
            zero_initialize: Zero all Linear weights and biases of the MLP.
            dtype: Dtype of the MLP layers and of the blend weight.
            scalar_temperature: Divisor applied to the raw weight before the sigmoid.

        Raises:
            ValueError: If ``activation`` is not supported.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.topk = topk
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.use_hidden_states = use_hidden_states
        self.scalar_temperature = scalar_temperature

        # Activation prototype; _build_mlp deep-copies it for every hidden layer.
        activation_classes = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}
        activation_class = activation_classes.get(activation.lower())
        if activation_class is None:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_class()

        # MLP mapping the network input (embed_dim) to the additive component (embed_dim)
        self.additive_network = self._build_mlp(
            embed_dim, hidden_dim, embed_dim, num_layers,
            use_layer_norm, dropout, dtype
        )

        # Raw blend weight (0-dim tensor); the effective weight is its sigmoid.
        self.additive_weight = nn.Parameter(torch.tensor(init_weight, dtype=dtype))

        # Optionally start with an MLP that outputs zeros.
        if zero_initialize:
            self._zero_init_network()

    def _build_mlp(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        use_layer_norm: bool,
        dropout: float,
        dtype: torch.dtype
    ) -> nn.Sequential:
        """
        Build the MLP that computes the additive component.

        ``num_layers == 1`` yields a single Linear. Otherwise the network is an
        input projection, ``num_layers - 2`` hidden layers and an output
        projection; every non-final Linear is followed by optional LayerNorm, the
        activation and dropout.
        """
        if num_layers == 1:
            return nn.Sequential(nn.Linear(input_dim, output_dim, dtype=dtype))

        layers = []
        current_dim = input_dim
        # One input projection plus (num_layers - 2) hidden layers, all same pattern.
        for _ in range(1 + max(0, num_layers - 2)):
            layers.append(nn.Linear(current_dim, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))
            current_dim = hidden_dim

        # Output projection (no activation)
        layers.append(nn.Linear(hidden_dim, output_dim, dtype=dtype))

        return nn.Sequential(*layers)

    def _zero_init_network(self):
        """Set all Linear weights and biases of the additive network to zero."""
        for module in self.additive_network.modules():
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        prev_inputs: torch.Tensor,
        logits: Optional[torch.Tensor] = None,
        embedding_weight: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the neural additive update.

        NOTE: the positional order here (``prev_inputs`` first) differs from
        ``InputUpdater.forward`` (``logits`` first). Pass arguments by keyword to
        avoid silently swapping ``logits`` and ``prev_inputs``.

        Args:
            prev_inputs: The previous inputs, shape (..., embed_dim)
            logits: The logits, shape (..., vocab_size); required unless
                ``use_hidden_states`` is True
            embedding_weight: The embedding matrix, shape (vocab_size, embed_dim);
                required unless ``use_hidden_states`` is True
            hidden_states: The hidden states, shape (..., embed_dim); required if
                ``use_hidden_states`` is True

        Returns:
            The updated inputs, shape (..., embed_dim)

        Raises:
            ValueError: If the tensors required by the configured mode are missing.
        """
        if self.use_hidden_states:
            # Hidden states are fed to the additive network as-is.
            if hidden_states is None:
                raise ValueError("hidden_states must be provided when use_hidden_states=True")
            network_input = hidden_states
        else:
            if logits is None or embedding_weight is None:
                raise ValueError("logits and embedding_weight must be provided when use_hidden_states=False")

            if self.topk is not None:
                # Softmax over the k largest logits only (k is capped at vocab_size).
                k = min(self.topk, logits.size(-1))
                topk_values, topk_indices = torch.topk(logits, k=k, dim=-1)
                topk_probs = torch.softmax(topk_values, dim=-1)  # (..., k)

                # Gather the matching embedding rows: (..., k, embed_dim)
                topk_embeddings = embedding_weight[topk_indices]

                # Weighted sum over the k candidates -> (..., embed_dim)
                network_input = torch.sum(topk_probs.unsqueeze(-1) * topk_embeddings, dim=-2)
            else:
                probs = torch.softmax(logits, dim=-1)
                # matmul does not type-promote and raises when logits and embeddings
                # have different dtypes; casting is a no-op when they already match.
                network_input = probs.to(embedding_weight.dtype) @ embedding_weight

        # (..., embed_dim) -> (..., embed_dim)
        additive_component = self.additive_network(network_input)

        # w multiplies prev_inputs; (1 - w) multiplies the additive component.
        scalar_weight = torch.sigmoid(self.additive_weight / self.scalar_temperature)
        updated_inputs = scalar_weight * prev_inputs + (1 - scalar_weight) * additive_component

        return updated_inputs


@register_input_updater
@capture_init_args
class HiddenMLPUpdater(InputUpdater):
    """
    Use selected layers from previous iteration's all_hidden_states to produce next input embeddings.

    Behavior:
    - Select layers by indices `hidden_states_layer_nums` from `all_hidden_states` (shape: (..., L, H))
    - Concatenate selected layers along the feature dimension -> (..., num_selected * H)
    - Apply an MLP to map to `embed_dim`

    Notes:
    - This updater relies solely on hidden states; logits/embedding_weight are ignored.
      `prev_inputs` is only used for its leading shape (and device/dtype in the fallback).
    - If `hidden_states_layer_nums` is empty, the last layer of `hidden_states` is used.
    - If a single hidden state is provided with shape (..., H), it is treated as one-layer
      input; the configured indices must then be empty or all 0, otherwise a ValueError
      is raised.
    - If `hidden_states` is None, falls back to zeros based on `hidden_states_size` and `hidden_states_layer_nums`.
    """

    def __init__(
        self,
        embed_dim: int=1024,
        hidden_states_size: int=1024,
        hidden_states_layer_nums: list=[28],
        hidden_dims: list = [256, 512, 256],
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Args:
            embed_dim: Output width of the MLP.
            hidden_states_size: Feature size H of each hidden-state layer.
            hidden_states_layer_nums: Non-negative layer indices to select; a falsy
                value (None or empty) means "use the last layer".
            hidden_dims: Widths of the MLP's hidden layers; a falsy value falls back
                to [256, 512, 256].
            dtype: Dtype of the MLP layers.
        """
        super().__init__()
        self.embed_dim = int(embed_dim)
        self.hidden_states_size = int(hidden_states_size)
        # The list defaults are copied here and never mutated, so sharing them
        # across instances is harmless.
        self.hidden_states_layer_nums = list(hidden_states_layer_nums or [])
        self.hidden_dims = list(hidden_dims or [256, 512, 256])

        # Build simple MLP: [concat_selected_hidden] -> hidden_dims -> embed_dim
        num_selected = max(1, len(self.hidden_states_layer_nums))
        input_dim = self.hidden_states_size * num_selected

        layers = []
        current_dim = input_dim
        for next_dim in self.hidden_dims:
            layers.append(nn.Linear(current_dim, next_dim, dtype=dtype))
            layers.append(nn.GELU())
            current_dim = next_dim
        layers.append(nn.Linear(current_dim, self.embed_dim, dtype=dtype))
        self.update_network = nn.Sequential(*layers)

    def _select_and_concat_layers(self, hidden_states: torch.Tensor, target_leading_shape: torch.Size) -> torch.Tensor:
        """
        Select requested layers and concatenate features to shape (..., num_selected*H).

        Raises:
            ValueError: If any configured layer index is negative or not smaller
                than the number of layers in ``hidden_states``.
        """
        # Normalize to (..., L, H)
        hs = hidden_states
        if hs.dim() == len(target_leading_shape) + 1:  # (..., H) -> (..., 1, H)
            hs = hs.unsqueeze(-2)
        total_layers = hs.size(-2)

        # An empty configuration selects the last layer, so `indices` is never empty.
        indices = self.hidden_states_layer_nums or [total_layers - 1]

        # Validate on the Python list: checking via tensor .item() would force a
        # host/device synchronization on every forward pass.
        if min(indices) < 0 or max(indices) >= total_layers:
            raise ValueError(
                f"hidden_states_layer_nums out of range: {indices}, total_layers={total_layers}"
            )

        index_tensor = torch.as_tensor(indices, device=hs.device, dtype=torch.long)

        # Select along the layer axis, then fold (num_selected, H) into one feature axis.
        selected = torch.index_select(hs, dim=-2, index=index_tensor)  # (..., num_selected, H)
        return selected.reshape(*target_leading_shape, selected.size(-2) * self.hidden_states_size)

    def forward(
        self,
        logits: torch.Tensor,
        prev_inputs: torch.Tensor,
        embedding_weight: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Map the selected hidden-state layers to new input embeddings.

        Args:
            logits: Unused by this updater.
            prev_inputs: Previous inputs, shape (..., embed_dim); provides the
                leading shape of the result.
            embedding_weight: Unused by this updater.
            hidden_states: Hidden states, shape (..., L, H) or (..., H); zeros are
                used when None.

        Returns:
            The updated inputs, shape (..., embed_dim).
        """
        # Determine leading shape (...)
        leading_shape = prev_inputs.shape[:-1]

        if hidden_states is None:
            # Fallback to zeros if all_hidden_states are not provided
            num_selected = max(1, len(self.hidden_states_layer_nums))
            concat = torch.zeros(
                *leading_shape,
                self.hidden_states_size * num_selected,
                device=prev_inputs.device,
                dtype=prev_inputs.dtype,
            )
        else:
            concat = self._select_and_concat_layers(hidden_states, target_leading_shape=leading_shape)

        # Map concatenated features to embed_dim
        return self.update_network(concat)

def save_input_updater(updater: InputUpdater, save_directory: str):
    """
    Save an updater's class name, state dict and constructor arguments.

    Writes a single file ``input_updater.bin`` into ``save_directory`` (created if
    it does not exist) containing a dict with the keys ``"class"``,
    ``"state_dict"`` (tensors moved to CPU) and ``"init_args"``.

    Args:
        updater: The updater to serialize. Its ``_init_args`` attribute is stored
            as the constructor arguments; an empty dict is stored if the
            attribute is missing.
        save_directory: Target directory.
    """
    # Constructor arguments recorded on the instance, if any.
    init_args = getattr(updater, '_init_args', {})

    # Store CPU copies so the file can be loaded without the original device.
    state_dict = {k: v.cpu() for k, v in updater.state_dict().items()}
    data = {
        "class": updater.__class__.__name__,
        "state_dict": state_dict,
        "init_args": init_args,
    }

    # torch.save does not create missing parent directories. An empty string means
    # the current directory, which os.makedirs would reject.
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)

    save_path = os.path.join(save_directory, "input_updater.bin")
    print(f"Saving input updater with {len(state_dict)} parameters to {save_path}")
    torch.save(data, save_path)


def load_input_updater(load_directory: str, class_name: Optional[str] = None, init_args: Optional[dict] = None) -> InputUpdater:
    """
    Load an updater from ``<load_directory>/input_updater.bin``.

    Args:
        load_directory: Directory containing ``input_updater.bin``.
        class_name: Registry name of the updater class; defaults to the name
            stored in the file.
        init_args: Constructor keyword arguments; defaults to those stored in the
            file (or none).

    Returns:
        The constructed updater with the stored state dict loaded
        (non-strictly). If ``init_args["learnable_weight"]`` is False (or the
        string "false"/"0"/"no"), all parameters are frozen.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If no class name is given and none is stored in the file.
    """
    path = os.path.join(load_directory, "input_updater.bin")

    if not os.path.isfile(path):
        raise FileNotFoundError(f"No input updater found at {path}")

    # NOTE(security): torch.load is pickle-based; only load files from trusted
    # sources. The stored payload also contains non-tensor values (class name,
    # init args), so review before switching to a restricted loading mode.
    data = torch.load(path, map_location="cpu")
    if class_name is None:
        class_name = data.get("class")

    if not class_name:
        raise ValueError("No updater class specified in saved data")

    # Fall back to the stored constructor arguments; tolerate a stored None.
    if init_args is None:
        init_args = data.get("init_args") or {}

    # Create updater instance using registry with proper arguments
    updater_class = get_input_updater_class(class_name)
    updater = updater_class(**init_args)

    state_dict = data.get("state_dict", {})
    if state_dict:
        # Drop state_dict entries whose key equals the name of an init argument.
        filtered_state_dict = {}
        for key, value in state_dict.items():
            if key not in init_args:
                filtered_state_dict[key] = value
            else:
                print(f"Skipping state_dict key '{key}' as it conflicts with init_args")

        print(f"Loading input updater state dict with {len(filtered_state_dict)} parameters (filtered from {len(state_dict)})")
        if filtered_state_dict:
            # strict=False tolerates mismatches; report them instead of hiding them.
            load_result = updater.load_state_dict(filtered_state_dict, strict=False)
            if load_result.missing_keys:
                print(f"Input updater state dict is missing keys: {load_result.missing_keys}")
            if load_result.unexpected_keys:
                print(f"Input updater state dict has unexpected keys: {load_result.unexpected_keys}")

    # If learnable_weight is explicitly set to false in init_args, freeze all parameters
    # This ensures no gradients are computed for any parameter in the updater
    lw_value = init_args.get("learnable_weight", None)
    should_freeze = False
    if isinstance(lw_value, bool):
        should_freeze = not lw_value
    elif isinstance(lw_value, str):
        # Accept common string spellings of "false" (e.g. from a text config).
        should_freeze = lw_value.strip().lower() in ("false", "0", "no")
    if should_freeze:
        for param in updater.parameters():
            param.requires_grad = False

    return updater
