"""
Projector nn module for the unified memory
"""

import torch
import torch.nn as nn
from torch import Tensor
from transformers import Cache, DynamicCache
import copy
from typing import Optional, Tuple
import math

from rosetta.utils.registry import register_model, get_projector_class, PROJECTOR_REGISTRY, capture_init_args, save_object, load_object

class Projector(nn.Module):
    """Base projector class for unified memory.

    A projector combines a source model's key/value tensors with a target
    model's key/value tensors. Subclasses implement :meth:`forward`;
    :meth:`cache_project` applies it to every layer of a pair of
    ``DynamicCache`` objects.
    """

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project and combine the source key-value tensors to the target key-value tensors
        Args:
            source_kv: Tuple of (key, value) tensors, each (..., D_s) where ... are arbitrary leading dimensions
            target_kv: Tuple of (key, value) tensors, each (..., D_t) where ... are arbitrary leading dimensions
        Returns:
            Tuple of (key, value) tensors, each (..., D_t) with same leading dimensions as input
        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement forward method")

    def cache_project(self, source_kv_cache: Cache, target_kv_cache: Cache) -> Cache:
        """
        Project the source kv cache to the target kv cache, one layer at a time.

        For every layer of the source cache, the layer's key/value tensors and the
        matching target tensors have dims 1 and 2 swapped, are passed through this
        module, and the result (swapped back) is stored in a fresh ``DynamicCache``.
        Source layers without a counterpart in the target cache are paired with
        all-zero target tensors of the source key's shape.

        Args:
            source_kv_cache: Cache holding the source tensors (must be a DynamicCache).
            target_kv_cache: Cache holding the target tensors (must be a DynamicCache).
        Returns:
            A new DynamicCache with one projected (key, value) entry per source layer.
        Raises:
            ValueError: if either cache is not a DynamicCache.
        """
        if not isinstance(source_kv_cache, DynamicCache) or not isinstance(target_kv_cache, DynamicCache):
            raise ValueError("Only DynamicCache is supported")

        projected_cache = DynamicCache()
        num_target_layers = len(target_kv_cache.key_cache)

        # Process each layer
        for layer_idx, source_key in enumerate(source_kv_cache.key_cache):
            source_value = source_kv_cache.value_cache[layer_idx]

            # Get corresponding target tensors (for reference/combination)
            if layer_idx < num_target_layers:
                target_key = target_kv_cache.key_cache[layer_idx]
                target_value = target_kv_cache.value_cache[layer_idx]
            else:
                # Target cache has no such layer: fall back to zero tensors shaped like
                # the source key (i.e. the target is assumed to share the source's dims).
                B, H, N, D = source_key.shape
                target_key = torch.zeros(B, H, N, D, device=source_key.device, dtype=source_key.dtype)
                target_value = torch.zeros(B, H, N, D, device=source_value.device, dtype=source_value.dtype)

            # Cache tensors are assumed to be (B, H, N, D); the projector is fed (B, N, H, D).
            # NOTE(review): several subclasses in this file unpack their forward() inputs as
            # (B, H, N, D) - confirm which layout they should receive from this method.
            source_kv = (source_key.transpose(1, 2), source_value.transpose(1, 2))
            target_kv = (target_key.transpose(1, 2), target_value.transpose(1, 2))

            # Invoke the module itself rather than .forward() so that hooks registered
            # on the projector are executed.
            projected_key, projected_value = self(source_kv, target_kv)

            # Swap dims back to the cache layout and store the layer.
            projected_cache.update(
                projected_key.transpose(1, 2),
                projected_value.transpose(1, 2),
                layer_idx,
            )

        return projected_cache

@register_model
@capture_init_args
class TrivialProjector(Projector):
    """
    No-op projector: the target key/value pair is passed through untouched.

    Handy as a baseline, or to switch projection off without changing the
    surrounding pipeline.
    """

    def __init__(self, **kwargs):
        """
        Build the projector; it has no parameters of its own.

        Args:
            **kwargs: Accepted and ignored, so the class can be constructed with
                the same keyword arguments as the other projectors.
        """
        super().__init__()

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Hand back ``target_kv`` as-is.

        Args:
            source_kv: Tuple of (key, value) tensors; never read.
            target_kv: Tuple of (key, value) tensors.
        Returns:
            The very same ``target_kv`` object that was passed in.
        """
        return target_kv


@register_model
@capture_init_args
class MLPProjector(Projector):
    """
    MLP-based projector that uses multi-layer perceptrons to project between different dimensions.

    Source keys and values are mapped from ``source_dim`` to ``target_dim`` by two
    independent MLPs, concatenated with the target tensors along the last
    dimension, and reduced back to ``target_dim`` by a linear combiner. With
    ``residual_connection`` enabled the target tensors are added to the result.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        residual_connection: bool = True,
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: Size of the last dimension of the source tensors.
            target_dim: Size of the last dimension of the target tensors.
            hidden_dim: Width of the hidden layers of each MLP.
            num_layers: Number of linear layers per MLP; values <= 1 give a single linear layer.
            dropout: Dropout probability applied after each hidden activation.
            activation: One of "gelu", "relu" or "silu" (case-insensitive).
            use_layer_norm: Whether to insert LayerNorm after each hidden linear layer.
            residual_connection: Whether forward() adds the target tensors to the combined output.
            dtype: Parameter dtype for all layers.
        Raises:
            ValueError: if ``activation`` is not supported.
        """
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # The residual in forward() adds the combiner output (target_dim) to the target
        # tensor (target_dim), so it is valid for any source_dim; no dimension check needed.
        self.residual_connection = residual_connection

        # Activation function
        activation_cls = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}.get(activation.lower())
        if activation_cls is None:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_cls()

        # Build separate MLP layers for key and value
        self.key_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)

        # Target combination layers (separate for key and value)
        self.key_combiner = nn.Linear(target_dim * 2, target_dim, dtype=dtype)
        self.value_combiner = nn.Linear(target_dim * 2, target_dim, dtype=dtype)

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """Build a single MLP projection module.

        Layout for ``num_layers > 1``: ``num_layers - 1`` blocks of
        Linear -> [LayerNorm] -> activation -> Dropout, followed by a final Linear
        to ``target_dim``. For ``num_layers <= 1`` the module is a single Linear.
        """
        # Single layer case: nothing but a direct linear map.
        if num_layers <= 1:
            return nn.Sequential(nn.Linear(source_dim, target_dim, dtype=dtype))

        layers = []
        in_features = source_dim
        # Input projection plus (num_layers - 2) hidden layers, all sharing one block shape.
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_features, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))
            in_features = hidden_dim

        # Output projection
        layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        return nn.Sequential(*layers)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project and combine the source key-value tensors to the target key-value tensors
        Args:
            source_kv: Tuple of (key, value) tensors, each (..., D_s) where ... are arbitrary leading dimensions
            target_kv: Tuple of (key, value) tensors, each (..., D_t) where ... are arbitrary leading dimensions
        Returns:
            Tuple of (key, value) tensors, each (..., D_t) with same leading dimensions as input
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Project source tensors separately for key and value
        projected_key = self.key_projection(source_key)  # (..., D_t)
        projected_value = self.value_projection(source_value)  # (..., D_t)

        # Combine with target tensors
        combined_key = torch.cat([projected_key, target_key], dim=-1)  # (..., 2*D_t)
        combined_value = torch.cat([projected_value, target_value], dim=-1)  # (..., 2*D_t)

        output_key = self.key_combiner(combined_key)  # (..., D_t)
        output_value = self.value_combiner(combined_value)  # (..., D_t)

        # Residual connection onto the target tensors
        if self.residual_connection:
            output_key = output_key + target_key
            output_value = output_value + target_value

        return (output_key, output_value)


@register_model
@capture_init_args
class SingleLinearReplaceProjector(Projector):
    """
    Replacement projector built from one linear layer per tensor kind.

    The heads of the source keys/values are flattened, sent through a single
    ``nn.Linear(source_dim, target_dim)`` each, and reshaped to the head layout
    of the target tensors. The target tensors only supply that layout; their
    values are not used in the output.
    """
    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_weight: float = 0.1,
        anneal_steps: int = 1360,
        initial_temperature: float = 1.0,
        final_temperature: float = 0.01,
        scalar_temperature: float = 0.005,
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: Input width of the linear layers; forward() feeds them the
                flattened heads, so this must equal H_s * D_s.
            target_dim: Output width of the linear layers; forward() reshapes it to
                (H_t, D_t), so this must equal H_t * D_t.
            hidden_dim: Stored on the instance; not used by the layers.
            num_layers: Stored on the instance; not used by the layers.
            activation: One of "gelu", "relu" or "silu" (case-insensitive); validated
                and stored, but not applied in forward().
            dtype: Parameter dtype of the linear layers.
            dropout, use_layer_norm, init_weight, anneal_steps, initial_temperature,
            final_temperature, scalar_temperature: accepted but not used by this class.
        Raises:
            ValueError: if ``activation`` is not supported.
        """
        super().__init__()

        self.source_dim, self.target_dim = source_dim, target_dim
        self.hidden_dim, self.num_layers = hidden_dim, num_layers

        # Resolve the activation name through a lookup table.
        activation_cls = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}.get(activation.lower())
        if activation_cls is None:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_cls()

        # One independent linear map for keys and one for values.
        self.key_projection = nn.Linear(source_dim, target_dim, dtype=dtype)
        self.value_projection = nn.Linear(source_dim, target_dim, dtype=dtype)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Map the source key/value tensors into the target head layout.

        Args:
            source_kv: Tuple of (key, value) tensors; the key must be 4-D and is read
                as (B, H_s, N, D_s). The value is reshaped with the key's dimensions.
            target_kv: Tuple of (key, value) tensors; only the key's shape is read,
                as (_, H_t, _, D_t).
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t).
        """
        src_k, src_v = source_kv
        tgt_k, _ = target_kv

        # Both shapes are unpacked as 4-D: (batch, heads, sequence, head_dim).
        bsz, src_heads, n_tokens, src_head_dim = src_k.shape
        _, tgt_heads, _, tgt_head_dim = tgt_k.shape

        def remap(tensor: Tensor, layer: nn.Module) -> Tensor:
            # (B, H_s, N, D_s) -> (B, N, H_s * D_s)
            flat = tensor.transpose(1, 2).contiguous().view(bsz, n_tokens, src_heads * src_head_dim)
            # (B, N, H_s * D_s) -> (B, N, H_t * D_t)
            mapped = layer(flat)
            # (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
            return mapped.view(bsz, n_tokens, tgt_heads, tgt_head_dim).transpose(1, 2)

        new_key = remap(src_k, self.key_projection)
        new_value = remap(src_v, self.value_projection)
        return (new_key, new_value)

@register_model
@capture_init_args
class ReplaceProjector(Projector):
    """
    Replacement projector backed by MLPs.

    The heads of the source keys/values are flattened, sent through separate
    key and value MLPs, and reshaped to the head layout of the target tensors.
    The target tensors only supply that layout; their values are not used in
    the output.
    """
    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_weight: float = 0.1,
        anneal_steps: int = 1360,
        initial_temperature: float = 1.0,
        final_temperature: float = 0.01,
        scalar_temperature: float = 0.005,
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: Input width of the MLPs; forward() feeds them the flattened
                heads, so this must equal H_s * D_s.
            target_dim: Output width of the MLPs; forward() reshapes it to
                (H_t, D_t), so this must equal H_t * D_t.
            hidden_dim: Width of the hidden layers of each MLP.
            num_layers: Number of linear layers per MLP; values <= 1 give a single linear layer.
            dropout: Dropout probability applied after each hidden activation.
            activation: One of "gelu", "relu" or "silu" (case-insensitive).
            use_layer_norm: Whether to insert LayerNorm after each hidden linear layer.
            dtype: Parameter dtype for all layers.
            init_weight, anneal_steps, initial_temperature, final_temperature,
            scalar_temperature: accepted but not used by this class.
        Raises:
            ValueError: if ``activation`` is not supported.
        """
        super().__init__()

        self.source_dim, self.target_dim = source_dim, target_dim
        self.hidden_dim, self.num_layers = hidden_dim, num_layers

        # Resolve the activation name through a lookup table.
        activation_cls = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}.get(activation.lower())
        if activation_cls is None:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_cls()

        # Keys and values get their own, identically configured MLP.
        mlp_config = (source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
        self.key_projection = self._build_mlp(*mlp_config)
        self.value_projection = self._build_mlp(*mlp_config)

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """Assemble one MLP: hidden blocks followed by an output Linear.

        Each hidden block is Linear -> [LayerNorm] -> activation -> Dropout. When
        ``num_layers <= 1`` the result is a single Linear(source_dim, target_dim).
        """
        def hidden_block(in_features: int) -> list:
            block = [nn.Linear(in_features, hidden_dim, dtype=dtype)]
            if use_layer_norm:
                block.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            block.append(copy.deepcopy(self.activation))
            block.append(nn.Dropout(dropout))
            return block

        # Input block, then (num_layers - 2) hidden-to-hidden blocks.
        modules = hidden_block(source_dim)
        for _ in range(num_layers - 2):
            modules.extend(hidden_block(hidden_dim))

        if num_layers > 1:
            modules.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        else:
            # Single layer case: the blocks built above are dropped in favour of one Linear.
            modules = [nn.Linear(source_dim, target_dim, dtype=dtype)]

        return nn.Sequential(*modules)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Map the source key/value tensors into the target head layout.

        Args:
            source_kv: Tuple of (key, value) tensors; the key must be 4-D and is read
                as (B, H_s, N, D_s). The value is reshaped with the key's dimensions.
            target_kv: Tuple of (key, value) tensors; only the key's shape is read,
                as (_, H_t, _, D_t).
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t).
        """
        src_k, src_v = source_kv
        tgt_k, _ = target_kv

        # Both shapes are unpacked as 4-D: (batch, heads, sequence, head_dim).
        bsz, src_heads, n_tokens, src_head_dim = src_k.shape
        _, tgt_heads, _, tgt_head_dim = tgt_k.shape

        def remap(tensor: Tensor, mlp: nn.Module) -> Tensor:
            # (B, H_s, N, D_s) -> (B, N, H_s * D_s)
            flat = tensor.transpose(1, 2).contiguous().view(bsz, n_tokens, src_heads * src_head_dim)
            # (B, N, H_s * D_s) -> (B, N, H_t * D_t)
            mapped = mlp(flat)
            # (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
            return mapped.view(bsz, n_tokens, tgt_heads, tgt_head_dim).transpose(1, 2)

        # Keys first, then values, so dropout draws happen in the same order as before.
        new_key = remap(src_k, self.key_projection)
        new_value = remap(src_v, self.value_projection)
        return (new_key, new_value)


@register_model
@capture_init_args
class AdditiveProjector(Projector):
    """
    Gated additive projector.

    The heads of the source keys/values are flattened, mapped by separate key and
    value MLPs to the flattened head space of the target, and blended with the
    target tensors as ``(1 - w) * target + gate * w * projected``. Here
    ``w = sigmoid(weight / scalar_temperature)`` is a learnable scalar (one for
    keys, one for values) and ``gate`` is a scalar gate: a Gumbel-sigmoid sample
    of ``gate_logit`` in training mode, a hard ``gate_logit > 0`` threshold in
    eval mode.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        source_num_heads: int,
        target_num_heads: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_weight: float = 0.1,
        anneal_steps: int = 1360,
        initial_temperature: float = 1.0,
        final_temperature: float = 0.01,
        scalar_temperature: float = 0.005,
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: Per-head width of the source tensors (MLP input is source_dim * source_num_heads).
            target_dim: Per-head width of the target tensors (MLP output is target_dim * target_num_heads).
            source_num_heads: Number of heads in the source tensors.
            target_num_heads: Number of heads in the target tensors.
            hidden_dim: Width of the hidden layers of each MLP.
            num_layers: Number of linear layers per MLP; values <= 1 give a single linear layer.
            dropout: Dropout probability applied after each hidden activation.
            activation: One of "gelu", "relu" or "silu" (case-insensitive).
            use_layer_norm: Whether to insert LayerNorm after each hidden linear layer.
            init_weight: Initial raw value of the learnable key and value blend weights.
            anneal_steps: Number of steps over which update_temperature() anneals the gate temperature.
            initial_temperature: Gate temperature at step 0.
            final_temperature: Gate temperature once annealing has finished.
            scalar_temperature: Divisor applied to the raw blend weights before the sigmoid.
            dtype: Parameter dtype for all layers and learnable scalars.
        Raises:
            ValueError: if ``activation`` is not supported.
        """
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Activation function
        activation_cls = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}.get(activation.lower())
        if activation_cls is None:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_cls()

        # Build separate MLP layers for key and value projection (flattened heads in, flattened heads out)
        self.key_projection = self._build_mlp(source_dim * source_num_heads, hidden_dim, target_dim * target_num_heads, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_dim * source_num_heads, hidden_dim, target_dim * target_num_heads, num_layers, use_layer_norm, dropout, dtype)

        # Learnable weights for additive combination (separate for key and value)
        self.key_weight = nn.Parameter(torch.tensor(init_weight, dtype=dtype))
        self.value_weight = nn.Parameter(torch.tensor(init_weight, dtype=dtype))

        self.gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))

        # Start the gate temperature at the configured initial value, so the module is
        # consistent with update_temperature(0) even before the first update call.
        self.register_buffer("gate_temperature", torch.tensor(float(initial_temperature)))
        self.initial_temperature = initial_temperature
        self.final_temperature = final_temperature
        self.anneal_steps = anneal_steps
        self.scalar_temperature = scalar_temperature

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """Build a single MLP projection module.

        Layout for ``num_layers > 1``: ``num_layers - 1`` blocks of
        Linear -> [LayerNorm] -> activation -> Dropout, followed by a final Linear
        to ``target_dim``. For ``num_layers <= 1`` the module is a single Linear.
        """
        # Single layer case: nothing but a direct linear map.
        if num_layers <= 1:
            return nn.Sequential(nn.Linear(source_dim, target_dim, dtype=dtype))

        layers = []
        in_features = source_dim
        # Input projection plus (num_layers - 2) hidden layers, all sharing one block shape.
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_features, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))
            in_features = hidden_dim

        # Output projection
        layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        return nn.Sequential(*layers)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule
        Args:
            step: current training step
        The temperature moves geometrically from ``initial_temperature`` (step 0) to
        ``final_temperature`` (step >= anneal_steps). Steps below 0 are treated as 0,
        and a non-positive ``anneal_steps`` means annealing is already complete.
        """
        if self.anneal_steps > 0:
            ratio = min(max(step / self.anneal_steps, 0.0), 1.0)
        else:
            # Nothing to anneal over: avoid dividing by zero and go straight to the final value.
            ratio = 1.0
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, device, dtype, eps=1e-20):
        """Draw Gumbel noise as ``-log(-log(U))`` with U uniform on [0, 1).

        ``eps`` is added inside both logarithms to keep their arguments away from zero.
        """
        U = torch.rand(shape, device=device, dtype=dtype)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_sigmoid_sample(self):
        """Return a relaxed gate sample: sigmoid of the noisy gate logit over the gate temperature."""
        dev = self.gate_logit.device
        dt = self.gate_logit.dtype
        g0 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        g1 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        y = torch.sigmoid((self.gate_logit + g1 - g0) / self.gate_temperature)
        return y

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project source key-value tensors to the target head layout and blend them with the target tensors
        Args:
            source_kv: Tuple of (key, value) tensors; the key must be 4-D and is read as
                (B, H_s, N, D_s). The value is reshaped with the key's dimensions.
            target_kv: Tuple of (key, value) tensors; the key must be 4-D and is read as (_, H_t, _, D_t)
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        """
        # Gate is sampled before the MLPs run, so random draws keep a fixed order.
        if self.training:
            # Relaxed (soft) gate: gumbel-sigmoid sample
            gate = self.gumbel_sigmoid_sample()
        else:
            # Deterministic hard gate at inference
            gate = (self.gate_logit > 0).float()

        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Both shapes are unpacked as 4-D: (batch, heads, sequence, head_dim)
        batch_size, source_num_heads, seq_len, source_head_dim = source_key.shape
        _, target_num_heads, _, target_head_dim = target_key.shape

        # Reshape source: merge num_heads and head_dim for projection
        # (B, H_s, N, D_s) -> (B, N, H_s * D_s)
        source_key_flat = source_key.transpose(1, 2).contiguous().view(batch_size, seq_len, source_num_heads * source_head_dim)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(batch_size, seq_len, source_num_heads * source_head_dim)

        # Project source tensors from (H_s * D_s) to (H_t * D_t)
        projected_key_flat = self.key_projection(source_key_flat)  # (B, N, H_t * D_t)
        projected_value_flat = self.value_projection(source_value_flat)  # (B, N, H_t * D_t)

        # Reshape projected tensors back to target format
        # (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
        projected_key = projected_key_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)
        projected_value = projected_value_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)

        # Squash the raw blend weights into (0, 1); the small temperature sharpens the sigmoid.
        normalized_key_weight = torch.sigmoid(self.key_weight / self.scalar_temperature)
        normalized_value_weight = torch.sigmoid(self.value_weight / self.scalar_temperature)

        # NOTE(review): the clone() is kept on purpose - presumably it protects the tensor saved
        # for backward from later in-place edits of the caller's cache tensors; confirm before removing.
        output_key = (1 - normalized_key_weight) * target_key.clone() + gate * normalized_key_weight * projected_key
        output_value = (1 - normalized_value_weight) * target_value.clone() + gate * normalized_value_weight * projected_value
        return (output_key, output_value)

@register_model
@capture_init_args
class ConcatAdditiveProjector(Projector):
    """
    Gated projector that fuses source and target KV tensors by concatenation.

    Source keys/values are flattened across heads and projected by an MLP to
    the flattened target width (``target_num_heads * target_dim``). The result
    is concatenated with the flattened target tensor and mixed by a linear
    "combiner". The combined tensor is then blended with the original target
    using a sigmoid-normalised learnable scalar and an input-dependent gate
    (Gumbel-sigmoid sample while training, hard threshold in eval mode).
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        source_num_heads: int,
        target_num_heads: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_weight: float = 0.1,
        anneal_steps: int = 1360,
        initial_temperature: float = 1.0,
        final_temperature: float = 0.01,
        scalar_temperature: float = 0.005,
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: Per-head feature size of the source keys/values.
            target_dim: Per-head feature size of the target keys/values.
            source_num_heads: Number of source KV heads.
            target_num_heads: Number of target KV heads.
            hidden_dim: Hidden width of the projection MLPs and gate generator.
            num_layers: Number of linear layers in each projection MLP.
            dropout: Dropout probability inside the projection MLPs.
            activation: One of "gelu", "relu" or "silu" (case-insensitive).
            use_layer_norm: Whether each hidden linear layer is followed by LayerNorm.
            init_weight: Initial raw value of the scalar key/value blend weights.
            anneal_steps: Number of steps over which the gate temperature is annealed.
            initial_temperature: Gate temperature at step 0.
            final_temperature: Gate temperature once annealing has finished.
            scalar_temperature: Divisor applied to the raw blend weights before the sigmoid.
            dtype: Parameter dtype of all sub-modules.
        Raises:
            ValueError: If ``activation`` is not supported.
        """
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Activation function
        activation_classes = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}
        activation_key = activation.lower()
        if activation_key not in activation_classes:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_classes[activation_key]()

        # Separate MLPs for key and value, mapping (H_s * D_s) -> (H_t * D_t)
        source_width = source_dim * source_num_heads
        target_width = target_dim * target_num_heads
        self.key_projection = self._build_mlp(source_width, hidden_dim, target_width, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_width, hidden_dim, target_width, num_layers, use_layer_norm, dropout, dtype)

        # Learnable scalar blend weights (separate for key and value)
        self.key_weight = nn.Parameter(torch.tensor(init_weight, dtype=dtype))
        self.value_weight = nn.Parameter(torch.tensor(init_weight, dtype=dtype))

        # Not a parameter: overwritten on every forward() with the logits
        # produced by gate_generator.
        self.gate_logit = None

        # Gate generator for input-dependent gating.
        # Input: last dimension of the target key (target_dim), Output: 1 logit
        self.gate_generator = nn.Sequential(
            nn.Linear(target_dim, hidden_dim, dtype=dtype),
            self.activation,
            nn.Linear(hidden_dim, 1, dtype=dtype)
        )

        # Start from the configured initial temperature so the gate is
        # consistent with the schedule even before update_temperature() is called.
        self.register_buffer("gate_temperature", torch.tensor(float(initial_temperature)))
        self.initial_temperature = initial_temperature
        self.final_temperature = final_temperature
        self.anneal_steps = anneal_steps
        self.scalar_temperature = scalar_temperature

        # Linear mixers applied to [projected_source ; target] along the feature axis
        self.key_combiner = nn.Linear(target_width * 2, target_width, dtype=dtype)
        self.value_combiner = nn.Linear(target_width * 2, target_width, dtype=dtype)

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """
        Build a single MLP projection module.

        With ``num_layers <= 1`` the MLP degenerates to one linear layer from
        ``source_dim`` to ``target_dim``; otherwise it is an input block,
        ``num_layers - 2`` hidden blocks and a final linear output layer.
        """
        if num_layers <= 1:
            # Single layer case
            return nn.Sequential(nn.Linear(source_dim, target_dim, dtype=dtype))

        layers = []
        in_features = source_dim
        # One input block followed by (num_layers - 2) hidden blocks
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_features, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))
            in_features = hidden_dim

        # Output projection
        layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        return nn.Sequential(*layers)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule
        Args:
            step: current training step
        """
        # A non-positive anneal_steps means "no annealing": jump to the final temperature
        if self.anneal_steps <= 0:
            ratio = 1.0
        else:
            ratio = min(step / self.anneal_steps, 1.0)
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, device, dtype, eps=1e-20):
        """Draw Gumbel noise of the given shape as ``-log(-log(U + eps) + eps)`` with U from ``torch.rand``."""
        U = torch.rand(shape, device=device, dtype=dtype)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_sigmoid_sample(self):
        """
        Sample a relaxed gate from the current ``self.gate_logit``.

        The difference of two independent Gumbel samples is added to the
        logits, and the result is passed through a sigmoid scaled by
        ``self.gate_temperature``.
        """
        dev = self.gate_logit.device
        dt = self.gate_logit.dtype
        g0 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        g1 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        return torch.sigmoid((self.gate_logit + g1 - g0) / self.gate_temperature)

    @staticmethod
    def _merge_heads(tensor: Tensor) -> Tensor:
        """Reshape (B, H, N, D) -> (B, N, H * D)."""
        batch_size, num_heads, seq_len, head_dim = tensor.shape
        return tensor.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * head_dim)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project source key-value tensors, combine them with the target tensors and blend with the target
        Args:
            source_kv: Tuple of (key, value) tensors, each unpacked as (B, H_s, N, D_s)
            target_kv: Tuple of (key, value) tensors, each unpacked as (B, H_t, N, D_t)
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Input-dependent gate computed from the target key: (B, H_t, N, D_t) -> (B, H_t, N, 1)
        self.gate_logit = self.gate_generator(target_key)
        if self.training:
            # Gumbel-sigmoid relaxation
            gate = self.gumbel_sigmoid_sample()
        else:
            # Hard gate. Cast to the logit dtype (not float32) so that half /
            # bfloat16 modules do not silently upcast the returned tensors.
            gate = (self.gate_logit > 0).to(self.gate_logit.dtype)

        # Shapes are assumed to be (B, H, N, D); the source sequence length is
        # also used for the target tensors.
        batch_size, _, seq_len, _ = source_key.shape
        _, target_num_heads, _, target_head_dim = target_key.shape

        # Project source tensors from (H_s * D_s) to (H_t * D_t)
        projected_key_flat = self.key_projection(self._merge_heads(source_key))  # (B, N, H_t * D_t)
        projected_value_flat = self.value_projection(self._merge_heads(source_value))  # (B, N, H_t * D_t)

        target_width = target_num_heads * target_head_dim
        target_key_flat = target_key.transpose(1, 2).contiguous().view(batch_size, seq_len, target_width)
        target_value_flat = target_value.transpose(1, 2).contiguous().view(batch_size, seq_len, target_width)

        # Concatenate along the feature axis and mix back down to the target width
        combined_key = self.key_combiner(torch.cat([projected_key_flat, target_key_flat], dim=-1))
        combined_value = self.value_combiner(torch.cat([projected_value_flat, target_value_flat], dim=-1))

        # (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
        combined_key = combined_key.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)
        combined_value = combined_value.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)

        # Blend: (1 - w) * target + gate * w * combined, with w = sigmoid(raw / scalar_temperature)
        normalized_key_weight = torch.sigmoid(self.key_weight / self.scalar_temperature)
        normalized_value_weight = torch.sigmoid(self.value_weight / self.scalar_temperature)
        # NOTE(review): the clone() presumably decouples autograd from later
        # in-place updates of the cache tensors; kept as in the original.
        output_key = (1 - normalized_key_weight) * target_key.clone() + gate * normalized_key_weight * combined_key
        output_value = (1 - normalized_value_weight) * target_value.clone() + gate * normalized_value_weight * combined_value
        return (output_key, output_value)


@register_model
@capture_init_args
class ExtendAdditiveProjector(Projector):
    """
    Projector that maps source key-value tensors to the target width with an
    MLP and combines them with the target tensors using learnable weights.

    The same learnable weight scales both the target tensor and the projected
    source tensor (see ``forward``).
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_type: str = "zero", # 'zero', 'ones', 'uniform', 'normal'
        gate_type: str = "channelwise",  # 'channelwise', 'valuewise' or 'layerwise'
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: Input width of the projection MLPs. ``forward`` feeds them
                head-flattened tensors, so this has to equal H_s * D_s.
            target_dim: Output width of the projection MLPs. ``forward`` reshapes the
                output to (H_t, D_t), so this has to equal H_t * D_t.
            hidden_dim: Hidden width of the projection MLPs.
            num_layers: Number of linear layers in each projection MLP.
            dropout: Dropout probability inside the projection MLPs.
            activation: One of "gelu", "relu" or "silu" (case-insensitive).
            use_layer_norm: Whether each hidden linear layer is followed by LayerNorm.
            init_type: Initialisation of the weights: 'zero', 'ones', 'uniform' or 'normal'.
            gate_type: Shape of the weights: 'channelwise' -> (target_dim,),
                'valuewise' -> (target_dim, source_dim), 'layerwise' -> (1,).
            dtype: Parameter dtype of all sub-modules.
        Raises:
            ValueError: If ``activation``, ``gate_type`` or ``init_type`` is not supported.
        """
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Activation function
        activation_classes = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}
        activation_key = activation.lower()
        if activation_key not in activation_classes:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_classes[activation_key]()

        # Build separate MLP layers for key and value projection
        self.key_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)

        # NOTE(review): in forward() the weights multiply (B, H_t, N, D_t) tensors.
        # 'valuewise' (target_dim, source_dim) and, for H_t > 1, 'channelwise'
        # (target_dim,) shapes do not look broadcastable against that layout;
        # confirm the intended usage with the callers.
        weight_shapes = {
            "channelwise": (target_dim,),            # one weight per channel
            "valuewise": (target_dim, source_dim),   # one weight per (target, source) pair
            "layerwise": (1,),                       # one weight for the whole layer
        }
        if gate_type not in weight_shapes:
            raise ValueError(f"Unsupported gate type: {gate_type}. Must be one of 'channelwise', 'valuewise', or 'layerwise'.")
        weight_shape = weight_shapes[gate_type]

        # Key is initialised before value (same order for the random initialisers)
        key_weight = self._init_weight(torch.empty(*weight_shape, dtype=dtype), init_type)
        value_weight = self._init_weight(torch.empty(*weight_shape, dtype=dtype), init_type)

        self.key_weight = nn.Parameter(key_weight, requires_grad=True)
        self.value_weight = nn.Parameter(value_weight, requires_grad=True)

    @staticmethod
    def _init_weight(weight: Tensor, init_type: str) -> Tensor:
        """Initialise ``weight`` in place according to ``init_type`` and return it."""
        if init_type == "zero":
            nn.init.zeros_(weight)
        elif init_type == "ones":
            nn.init.ones_(weight)
        elif init_type == "uniform":
            nn.init.uniform_(weight, a=0.0, b=1.0)
        elif init_type == "normal":
            nn.init.normal_(weight, mean=0.0, std=1.0)
        else:
            raise ValueError(f"Unsupported init type: {init_type}. Must be one of 'zero', 'ones', 'uniform', or 'normal'.")
        return weight

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """
        Build a single MLP projection module.

        With ``num_layers <= 1`` the MLP degenerates to one linear layer from
        ``source_dim`` to ``target_dim``; otherwise it is an input block,
        ``num_layers - 2`` hidden blocks and a final linear output layer.
        """
        if num_layers <= 1:
            # Single layer case
            return nn.Sequential(nn.Linear(source_dim, target_dim, dtype=dtype))

        layers = []
        in_features = source_dim
        # One input block followed by (num_layers - 2) hidden blocks
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_features, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))
            in_features = hidden_dim

        # Output projection
        layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        return nn.Sequential(*layers)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project source key-value tensors to the target layout and combine them with the target tensors
        Args:
            source_kv: Tuple of (key, value) tensors, each unpacked as (B, H_s, N, D_s)
            target_kv: Tuple of (key, value) tensors, each unpacked as (B, H_t, N, D_t)
        Returns:
            Tuple of (key, value) tensors computed as weight * target + weight * projected_source
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Shapes are assumed to be (B, H, N, D); the source sequence length is
        # also used when reshaping the projection back to the target layout.
        batch_size, source_num_heads, seq_len, source_head_dim = source_key.shape
        _, target_num_heads, _, target_head_dim = target_key.shape

        # Merge heads for projection: (B, H_s, N, D_s) -> (B, N, H_s * D_s)
        source_width = source_num_heads * source_head_dim
        source_key_flat = source_key.transpose(1, 2).contiguous().view(batch_size, seq_len, source_width)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(batch_size, seq_len, source_width)

        # Project source tensors from (H_s * D_s) to (H_t * D_t)
        projected_key_flat = self.key_projection(source_key_flat)  # (B, N, H_t * D_t)
        projected_value_flat = self.value_projection(source_value_flat)  # (B, N, H_t * D_t)

        # Split heads again: (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
        projected_key = projected_key_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)
        projected_value = projected_value_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)

        # NOTE(review): the target term is scaled by the same weight as the
        # projected term, so with init_type="zero" the output starts as all
        # zeros. Confirm this is intended (rather than e.g. 1 - weight on the
        # target) before changing it; the arithmetic is kept as in the original.
        output_key = (self.key_weight) * target_key + self.key_weight * projected_key
        output_value = (self.value_weight) * target_value + self.value_weight * projected_value

        return (output_key, output_value)


@register_model
@capture_init_args
class ExtendGatedProjector(Projector):
    """
    Gated projector that projects source key-value tensors to target dimension using MLP,
    then adds them to target tensors using learnable weights. The added term is
    additionally multiplied by a learnable gate: a Gumbel-sigmoid sample of
    ``gate_logit`` while training and a hard ``gate_logit > 0`` threshold in eval mode.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_type: str = "zero", # 'zero', 'ones', 'uniform', 'normal'
        gate_type: str = "channelwise",  # 'channelwise' or 'layerwise'
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: Input width of the projection MLPs. ``forward`` feeds them
                head-flattened tensors, so this has to equal H_s * D_s.
            target_dim: Output width of the projection MLPs. ``forward`` reshapes the
                output to (H_t, D_t), so this has to equal H_t * D_t.
            hidden_dim: Hidden width of the projection MLPs.
            num_layers: Number of linear layers in each projection MLP.
            dropout: Dropout probability inside the projection MLPs.
            activation: One of "gelu", "relu" or "silu" (case-insensitive).
            use_layer_norm: Whether each hidden linear layer is followed by LayerNorm.
            init_type: Initialisation of the weights: 'zero', 'ones', 'uniform' or 'normal'.
            gate_type: 'channelwise' -> weights and gate logits of shape (target_dim,);
                'layerwise' -> weights of shape (1,) and a scalar gate logit.
            dtype: Parameter dtype of all sub-modules.
        Raises:
            ValueError: If ``activation``, ``gate_type`` or ``init_type`` is not supported.
        """
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Activation function
        activation_classes = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}
        activation_key = activation.lower()
        if activation_key not in activation_classes:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_classes[activation_key]()

        # Build separate MLP layers for key and value projection
        self.key_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)

        # Learnable weights for additive combination (separate for key and value).
        # gate_logit is registered BEFORE key_weight / value_weight, which keeps
        # the parameter ordering of the module unchanged.
        if gate_type == "channelwise":
            # Channel-wise gating: learnable weights per channel
            # NOTE(review): a (target_dim,) weight multiplies (B, H_t, N, D_t)
            # tensors in forward(); this only broadcasts when target_dim matches
            # the last axis - confirm the intended head configuration.
            key_weight = torch.empty(target_dim, dtype=dtype)
            value_weight = torch.empty(target_dim, dtype=dtype)
            self.gate_logit = nn.Parameter(torch.zeros(target_dim, dtype=dtype))
        elif gate_type == "layerwise":
            key_weight = torch.empty(1, dtype=dtype)
            value_weight = torch.empty(1, dtype=dtype)
            self.gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))
        else:
            raise ValueError(f"Unsupported gate type: {gate_type}. Must be one of 'channelwise' or 'layerwise'.")

        # Key is initialised before value (same order for the random initialisers)
        self._init_weight(key_weight, init_type)
        self._init_weight(value_weight, init_type)

        self.key_weight = nn.Parameter(key_weight, requires_grad=True)
        self.value_weight = nn.Parameter(value_weight, requires_grad=True)

        self.register_buffer("gate_temperature", torch.tensor(1.0))  # initial temperature
        self.initial_temperature = 1.0
        self.final_temperature = 0.01
        self.anneal_steps = 440

    @staticmethod
    def _init_weight(weight: Tensor, init_type: str) -> Tensor:
        """Initialise ``weight`` in place according to ``init_type`` and return it."""
        if init_type == "zero":
            nn.init.zeros_(weight)
        elif init_type == "ones":
            nn.init.ones_(weight)
        elif init_type == "uniform":
            nn.init.uniform_(weight, a=0.0, b=1.0)
        elif init_type == "normal":
            nn.init.normal_(weight, mean=0.0, std=1.0)
        else:
            raise ValueError(f"Unsupported init type: {init_type}. Must be one of 'zero', 'ones', 'uniform', or 'normal'.")
        return weight

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """
        Build a single MLP projection module.

        With ``num_layers <= 1`` the MLP degenerates to one linear layer from
        ``source_dim`` to ``target_dim``; otherwise it is an input block,
        ``num_layers - 2`` hidden blocks and a final linear output layer.
        """
        if num_layers <= 1:
            # Single layer case
            return nn.Sequential(nn.Linear(source_dim, target_dim, dtype=dtype))

        layers = []
        in_features = source_dim
        # One input block followed by (num_layers - 2) hidden blocks
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(in_features, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))
            in_features = hidden_dim

        # Output projection
        layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        return nn.Sequential(*layers)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule
        Args:
            step: current training step
        """
        ratio = min(step / self.anneal_steps, 1.0)
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, device, dtype, eps=1e-20):
        """Draw Gumbel noise of the given shape as ``-log(-log(U + eps) + eps)`` with U from ``torch.rand``."""
        U = torch.rand(shape, device=device, dtype=dtype)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_sigmoid_sample(self):
        """
        Sample a relaxed gate from ``self.gate_logit``.

        The difference of two independent Gumbel samples is added to the
        logits, and the result is passed through a sigmoid scaled by
        ``self.gate_temperature``.
        """
        dev = self.gate_logit.device
        dt = self.gate_logit.dtype
        g0 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        g1 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        return torch.sigmoid((self.gate_logit + g1 - g0) / self.gate_temperature)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project source key-value tensors to the target layout and add the gated, weighted result to the target
        Args:
            source_kv: Tuple of (key, value) tensors, each unpacked as (B, H_s, N, D_s)
            target_kv: Tuple of (key, value) tensors, each unpacked as (B, H_t, N, D_t)
        Returns:
            Tuple of (key, value) tensors computed as target + gate * weight * projected_source
        """
        if self.training:
            # Gumbel-sigmoid relaxation of the gate
            gate = self.gumbel_sigmoid_sample()
        else:
            # Hard gate. Cast to the logit dtype (not float32) so that half /
            # bfloat16 modules do not silently upcast the returned tensors.
            gate = (self.gate_logit > 0).to(self.gate_logit.dtype)

        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Shapes are assumed to be (B, H, N, D); the source sequence length is
        # also used when reshaping the projection back to the target layout.
        batch_size, source_num_heads, seq_len, source_head_dim = source_key.shape
        _, target_num_heads, _, target_head_dim = target_key.shape

        # Merge heads for projection: (B, H_s, N, D_s) -> (B, N, H_s * D_s)
        source_width = source_num_heads * source_head_dim
        source_key_flat = source_key.transpose(1, 2).contiguous().view(batch_size, seq_len, source_width)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(batch_size, seq_len, source_width)

        # Project source tensors from (H_s * D_s) to (H_t * D_t)
        projected_key_flat = self.key_projection(source_key_flat)  # (B, N, H_t * D_t)
        projected_value_flat = self.value_projection(source_value_flat)  # (B, N, H_t * D_t)

        # Split heads again: (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
        projected_key = projected_key_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)
        projected_value = projected_value_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)

        # Add projected source to target with learnable weights, scaled by the gate
        output_key = target_key + gate * self.key_weight * projected_key
        output_value = target_value + gate * self.value_weight * projected_value
        return (output_key, output_value)

class SwigLUBlock(nn.Module):
    """
    Gated linear unit with a SiLU gate.

    Two parallel linear maps are applied to the input; the SiLU-activated
    output of ``gate_proj`` is multiplied element-wise with the output of
    ``act_proj``.
    """

    def __init__(self, input_dim, output_dim, dtype):
        """
        Args:
            input_dim: Size of the last dimension of the input.
            output_dim: Size of the last dimension of the output.
            dtype: Parameter dtype of both linear layers.
        """
        super().__init__()
        # Branch whose SiLU activation acts as the multiplicative gate
        self.gate_proj = nn.Linear(input_dim, output_dim, dtype=dtype)
        # Branch carrying the values that are gated
        self.act_proj = nn.Linear(input_dim, output_dim, dtype=dtype)
        self.activation = nn.SiLU()

    def forward(self, x):
        """Return ``SiLU(gate_proj(x)) * act_proj(x)``."""
        gating = self.activation(self.gate_proj(x))
        values = self.act_proj(x)
        return gating * values

@register_model
@capture_init_args
class TokenWiseAdditiveProjector(Projector):
    """
    Token-wise gated projector with one scalar blend weight per position.

    Source keys/values are flattened over heads, mapped through separate key and
    value MLPs, reshaped to the target head layout and blended with the target
    tensors as

        out = (1 - w) * target + gate * w * projected

    where ``w`` is a temperature-sharpened sigmoid weight and ``gate`` is a
    Gumbel-sigmoid sample in training mode or a hard 0/1 threshold in eval mode.
    Both are predicted from the target key only.

    NOTE(review): ``key_projection``/``value_projection`` emit ``target_dim``
    features which forward() views as (H_t, D_t), while ``weight_hidden`` and
    ``gate_generator`` expect ``target_dim`` input features and are applied to
    the last axis (D_t) of the target key. As written these requirements only
    agree when the target has a single head -- confirm against callers.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_type: str = "zero", # 'zero', 'ones', 'uniform', 'normal'
        anneal_steps: int = 1360,
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: input feature size of the key/value MLPs
            target_dim: output feature size of the key/value MLPs and input
                feature size of the weight/gate generators
            hidden_dim: width of all hidden layers
            num_layers: number of linear layers in each projection MLP
            dropout: dropout probability used in the MLPs and the weight generator
            activation: one of 'gelu', 'relu', 'silu' (case-insensitive)
            use_layer_norm: insert LayerNorm after each hidden linear layer
            init_type: initialisation of the weight-head matrices; any value other
                than 'zero', 'ones', 'uniform', 'normal' keeps nn.Linear's default
            anneal_steps: number of steps over which the gate temperature decays
            dtype: parameter dtype
        Raises:
            ValueError: if ``activation`` is not supported
        """
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Activation function
        if activation.lower() == "gelu":
            self.activation = nn.GELU()
        elif activation.lower() == "relu":
            self.activation = nn.ReLU()
        elif activation.lower() == "silu":
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        # Build separate MLP layers for key and value projection
        self.key_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)

        # Per-token dynamic weight generators (shared hidden layers)
        self.weight_hidden = nn.Sequential(
            nn.Linear(target_dim, hidden_dim, dtype=dtype),
            self.activation,
            nn.Dropout(dropout)
        )

        # Separate heads for key and value weights (one scalar per position)
        self.key_weight_head = nn.Linear(hidden_dim, 1, dtype=dtype)
        self.value_weight_head = nn.Linear(hidden_dim, 1, dtype=dtype)

        # Only the weight matrices are re-initialised; biases keep the nn.Linear default.
        if init_type == "zero":
            nn.init.zeros_(self.key_weight_head.weight)
            nn.init.zeros_(self.value_weight_head.weight)
        elif init_type == "ones":
            nn.init.ones_(self.key_weight_head.weight)
            nn.init.ones_(self.value_weight_head.weight)
        elif init_type == "uniform":
            nn.init.uniform_(self.key_weight_head.weight, a=0.0, b=1.0)
            nn.init.uniform_(self.value_weight_head.weight, a=0.0, b=1.0)
        elif init_type == "normal":
            nn.init.normal_(self.key_weight_head.weight, mean=0.0, std=1.0)
            nn.init.normal_(self.value_weight_head.weight, mean=0.0, std=1.0)

        # Per-token gate generator
        self.gate_generator = nn.Sequential(
            nn.Linear(target_dim, hidden_dim, dtype=dtype),
            self.activation,
            nn.Linear(hidden_dim, 1, dtype=dtype)
        )

        # Gumbel-sigmoid temperature; a buffer so it is saved with the state dict.
        self.register_buffer("gate_temperature", torch.tensor(1.0))  # initial temperature
        self.initial_temperature = 1.0
        self.final_temperature = 0.01
        self.anneal_steps = anneal_steps
        # Fixed divisor that sharpens the blend-weight sigmoid in forward().
        self.scalar_temperature = 0.005

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """
        Build a single MLP projection module
        Args:
            source_dim: input feature size
            hidden_dim: hidden feature size
            target_dim: output feature size
            num_layers: number of linear layers; values <= 1 give a single linear map
            use_layer_norm: insert LayerNorm after each hidden linear layer
            dropout: dropout probability after each hidden activation
            dtype: parameter dtype
        Returns:
            The assembled nn.Sequential
        """
        layers = []

        # Input projection
        layers.append(nn.Linear(source_dim, hidden_dim, dtype=dtype))
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
        # Each MLP gets its own activation instance
        layers.append(copy.deepcopy(self.activation))
        layers.append(nn.Dropout(dropout))

        # Hidden layers
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))

        # Output projection
        if num_layers > 1:
            layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        else:
            # Single layer case: everything built above is replaced by one linear map
            layers = [nn.Linear(source_dim, target_dim, dtype=dtype)]

        return nn.Sequential(*layers)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule
        Args:
            step: current training step
        """
        if self.anneal_steps <= 0:
            # No annealing window: go straight to the final temperature
            # instead of dividing by zero.
            ratio = 1.0
        else:
            ratio = min(step / self.anneal_steps, 1.0)
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, eps=1e-20, device=None):
        """
        Draw Gumbel(0, 1) noise via the inverse-CDF transform of uniform samples
        Args:
            shape: shape of the returned tensor
            eps: small constant keeping both logarithms finite
            device: device to sample on (default: torch's current default device)
        Returns:
            Tensor of the given shape in torch's default dtype
        """
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_sigmoid_sample(self, gate_logit):
        """
        Relaxed Bernoulli sample: sigmoid of the noisy logit divided by the gate temperature
        Args:
            gate_logit: tensor of gate logits
        Returns:
            Tensor of soft gates in (0, 1) with the same shape as ``gate_logit``
        """
        # Sample directly on the logits' device to avoid a host-to-device copy per call.
        g0 = self.sample_gumbel(gate_logit.size(), device=gate_logit.device)
        g1 = self.sample_gumbel(gate_logit.size(), device=gate_logit.device)
        return torch.sigmoid((gate_logit + g1 - g0) / self.gate_temperature)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project source key-value tensors to the target layout and blend them with the target tensors
        Args:
            source_kv: Tuple of (key, value) tensors, each (B, H_s, N, D_s)
            target_kv: Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Layout is (B, H, N, D): H is num_heads, N is seq_len, D is head_dim
        batch_size, source_num_heads, seq_len, source_head_dim = source_key.shape
        _, target_num_heads, _, target_head_dim = target_key.shape

        # Merge heads for projection: (B, H_s, N, D_s) -> (B, N, H_s * D_s)
        source_flat_dim = source_num_heads * source_head_dim
        source_key_flat = source_key.transpose(1, 2).contiguous().view(batch_size, seq_len, source_flat_dim)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(batch_size, seq_len, source_flat_dim)

        # Project source tensors from (H_s * D_s) to (H_t * D_t)
        projected_key_flat = self.key_projection(source_key_flat)  # (B, N, H_t * D_t)
        projected_value_flat = self.value_projection(source_value_flat)  # (B, N, H_t * D_t)

        # Split heads again: (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
        projected_key = projected_key_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)
        projected_value = projected_value_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)

        # Weights and gates currently depend on the target key only
        weight_hidden = self.weight_hidden(target_key)  # (B, H_t, N, hidden_dim)
        gate_logit = self.gate_generator(target_key)  # (B, H_t, N, 1)

        if self.training:
            # Differentiable soft gate via Gumbel-sigmoid
            gate = self.gumbel_sigmoid_sample(gate_logit)
        else:
            # Deterministic hard gate; .float() makes it float32
            gate = (gate_logit > 0).float()

        # Dividing by the small scalar temperature pushes the sigmoid towards 0/1
        normalized_key_weight = torch.sigmoid(self.key_weight_head(weight_hidden) / self.scalar_temperature)  # (B, H_t, N, 1)
        normalized_value_weight = torch.sigmoid(self.value_weight_head(weight_hidden) / self.scalar_temperature)  # (B, H_t, N, 1)

        # Convex-style blend; the arithmetic allocates new tensors, so the inputs are not modified
        output_key = (1 - normalized_key_weight) * target_key + gate * normalized_key_weight * projected_key
        output_value = (1 - normalized_value_weight) * target_value + gate * normalized_value_weight * projected_value
        return (output_key, output_value)


@register_model
@capture_init_args
class TokenWiseGatedProjector(Projector):
    """
    Token-wise gated projector with per-channel or per-position blend weights.

    Source keys/values are flattened over heads, mapped through separate key and
    value MLPs from ``source_dim * source_num_heads`` to
    ``target_dim * target_num_heads`` features, reshaped to the target head
    layout and blended with the target tensors as

        out = (1 - w) * target + gate * w * projected

    ``w`` (temperature-sharpened sigmoid) and ``gate`` (Gumbel-sigmoid sample in
    training mode, hard 0/1 threshold in eval mode) are predicted from the target
    key, applied along its last axis. With ``structure_type='channelwise'`` they
    carry ``target_dim`` values per position and head; otherwise a single value.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        source_num_heads: int,
        target_num_heads: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_type: str = "zero", # 'zero', 'ones', 'uniform', 'normal'
        anneal_steps: int = 1360,
        structure_type: str = "channelwise",  # 'channelwise', or 'layerwise'
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: per-head feature size of the source tensors
            target_dim: per-head feature size of the target tensors
            source_num_heads: number of source heads
            target_num_heads: number of target heads
            hidden_dim: width of all hidden layers
            num_layers: number of linear layers in each projection MLP
            dropout: dropout probability used in the MLPs and the weight generator
            activation: one of 'gelu', 'relu', 'silu' (case-insensitive)
            use_layer_norm: insert LayerNorm after each hidden linear layer
            init_type: initialisation of the weight-head matrices; any value other
                than 'zero', 'ones', 'uniform', 'normal' keeps nn.Linear's default
            anneal_steps: number of steps over which the gate temperature decays
            structure_type: 'channelwise' for one weight/gate per channel; any
                other value gives one weight/gate per position and head
            dtype: parameter dtype
        Raises:
            ValueError: if ``activation`` is not supported
        """
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.source_num_heads = source_num_heads
        self.target_num_heads = target_num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Activation function
        if activation.lower() == "gelu":
            self.activation = nn.GELU()
        elif activation.lower() == "relu":
            self.activation = nn.ReLU()
        elif activation.lower() == "silu":
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        # Build separate MLP layers for key and value projection (heads merged)
        self.key_projection = self._build_mlp(source_dim * source_num_heads, hidden_dim, target_dim * target_num_heads, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_dim * source_num_heads, hidden_dim, target_dim * target_num_heads, num_layers, use_layer_norm, dropout, dtype)

        # Per-token dynamic weight generators (shared hidden layers)
        self.weight_hidden = nn.Sequential(
            nn.Linear(target_dim, hidden_dim, dtype=dtype),
            self.activation,
            nn.Dropout(dropout)
        )

        if structure_type == 'channelwise':
            # Separate heads for key and value weights, one value per channel
            self.key_weight_head = nn.Linear(hidden_dim, target_dim, dtype=dtype)
            self.value_weight_head = nn.Linear(hidden_dim, target_dim, dtype=dtype)
            # Per-token gate generator
            self.gate_generator = nn.Sequential(
                nn.Linear(target_dim, hidden_dim, dtype=dtype),
                self.activation,
                nn.Linear(hidden_dim, target_dim, dtype=dtype)
            )
        else:
            # One scalar weight per position and head
            self.key_weight_head = nn.Linear(hidden_dim, 1, dtype=dtype)
            self.value_weight_head = nn.Linear(hidden_dim, 1, dtype=dtype)
            # Per-token gate generator
            self.gate_generator = nn.Sequential(
                nn.Linear(target_dim, hidden_dim, dtype=dtype),
                self.activation,
                nn.Linear(hidden_dim, 1, dtype=dtype)
            )

        # Only the weight matrices are re-initialised; biases keep the nn.Linear default.
        if init_type == "zero":
            nn.init.zeros_(self.key_weight_head.weight)
            nn.init.zeros_(self.value_weight_head.weight)
        elif init_type == "ones":
            nn.init.ones_(self.key_weight_head.weight)
            nn.init.ones_(self.value_weight_head.weight)
        elif init_type == "uniform":
            nn.init.uniform_(self.key_weight_head.weight, a=0.0, b=1.0)
            nn.init.uniform_(self.value_weight_head.weight, a=0.0, b=1.0)
        elif init_type == "normal":
            nn.init.normal_(self.key_weight_head.weight, mean=0.0, std=1.0)
            nn.init.normal_(self.value_weight_head.weight, mean=0.0, std=1.0)

        # Gumbel-sigmoid temperature; a buffer so it is saved with the state dict.
        self.register_buffer("gate_temperature", torch.tensor(1.0))  # initial temperature
        self.initial_temperature = 1.0
        self.final_temperature = 0.01
        self.anneal_steps = anneal_steps
        # Fixed divisor that sharpens the blend-weight sigmoid in forward().
        self.scalar_temperature = 0.005

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """
        Build a single MLP projection module
        Args:
            source_dim: input feature size
            hidden_dim: hidden feature size
            target_dim: output feature size
            num_layers: number of linear layers; values <= 1 give a single linear map
            use_layer_norm: insert LayerNorm after each hidden linear layer
            dropout: dropout probability after each hidden activation
            dtype: parameter dtype
        Returns:
            The assembled nn.Sequential
        """
        layers = []

        # Input projection
        layers.append(nn.Linear(source_dim, hidden_dim, dtype=dtype))
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
        # Each MLP gets its own activation instance
        layers.append(copy.deepcopy(self.activation))
        layers.append(nn.Dropout(dropout))

        # Hidden layers
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))

        # Output projection
        if num_layers > 1:
            layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        else:
            # Single layer case: everything built above is replaced by one linear map
            layers = [nn.Linear(source_dim, target_dim, dtype=dtype)]

        return nn.Sequential(*layers)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule
        Args:
            step: current training step
        """
        if self.anneal_steps <= 0:
            # No annealing window: go straight to the final temperature
            # instead of dividing by zero.
            ratio = 1.0
        else:
            ratio = min(step / self.anneal_steps, 1.0)
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, eps=1e-20, device=None):
        """
        Draw Gumbel(0, 1) noise via the inverse-CDF transform of uniform samples
        Args:
            shape: shape of the returned tensor
            eps: small constant keeping both logarithms finite
            device: device to sample on (default: torch's current default device)
        Returns:
            Tensor of the given shape in torch's default dtype
        """
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_sigmoid_sample(self, gate_logit):
        """
        Relaxed Bernoulli sample: sigmoid of the noisy logit divided by the gate temperature
        Args:
            gate_logit: tensor of gate logits
        Returns:
            Tensor of soft gates in (0, 1) with the same shape as ``gate_logit``
        """
        # Sample directly on the logits' device to avoid a host-to-device copy per call.
        g0 = self.sample_gumbel(gate_logit.size(), device=gate_logit.device)
        g1 = self.sample_gumbel(gate_logit.size(), device=gate_logit.device)
        return torch.sigmoid((gate_logit + g1 - g0) / self.gate_temperature)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project source key-value tensors to the target layout and blend them with the target tensors
        Args:
            source_kv: Tuple of (key, value) tensors, each (B, H_s, N, D_s)
            target_kv: Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Layout is (B, H, N, D): H is num_heads, N is seq_len, D is head_dim
        batch_size, source_num_heads, seq_len, source_head_dim = source_key.shape
        _, target_num_heads, _, target_head_dim = target_key.shape

        # Merge heads for projection: (B, H_s, N, D_s) -> (B, N, H_s * D_s)
        source_flat_dim = source_num_heads * source_head_dim
        source_key_flat = source_key.transpose(1, 2).contiguous().view(batch_size, seq_len, source_flat_dim)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(batch_size, seq_len, source_flat_dim)

        # Project source tensors from (H_s * D_s) to (H_t * D_t)
        projected_key_flat = self.key_projection(source_key_flat)  # (B, N, H_t * D_t)
        projected_value_flat = self.value_projection(source_value_flat)  # (B, N, H_t * D_t)

        # Split heads again: (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
        projected_key = projected_key_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)
        projected_value = projected_value_flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)

        # Weights and gates currently depend on the target key only
        weight_hidden = self.weight_hidden(target_key)  # (B, H_t, N, hidden_dim)

        # Per-token gates: (B, H_t, N, D_t) when channelwise, else (B, H_t, N, 1)
        gate_logit = self.gate_generator(target_key)

        if self.training:
            # Differentiable soft gate via Gumbel-sigmoid
            gate = self.gumbel_sigmoid_sample(gate_logit)
        else:
            # Deterministic hard gate; .float() makes it float32
            gate = (gate_logit > 0).float()

        # Dividing by the small scalar temperature pushes the sigmoid towards 0/1
        normalized_key_weight = torch.sigmoid(self.key_weight_head(weight_hidden) / self.scalar_temperature)
        normalized_value_weight = torch.sigmoid(self.value_weight_head(weight_hidden) / self.scalar_temperature)

        # Convex-style blend; the arithmetic allocates new tensors, so the inputs are not modified
        output_key = (1 - normalized_key_weight) * target_key + gate * normalized_key_weight * projected_key
        output_value = (1 - normalized_value_weight) * target_value + gate * normalized_value_weight * projected_value
        return (output_key, output_value)


@register_model
@capture_init_args
class ConcatTokenWiseGatedProjector(Projector):
    """
    Token-wise gated projector that fuses projected source and target by concatenation.

    Source keys/values are flattened over heads and mapped through separate MLPs
    to ``target_dim * target_num_heads`` features. The result is concatenated
    with the head-flattened target tensor and reduced back to
    ``target_dim * target_num_heads`` features by a linear combiner. That fused
    tensor is then blended with the original target as

        out = (1 - w) * target + gate * w * fused

    ``w`` (temperature-sharpened sigmoid) and ``gate`` (Gumbel-sigmoid sample in
    training mode, hard 0/1 threshold in eval mode) are predicted from the
    head-flattened target key. With ``structure_type='channelwise'`` they carry
    one value per head and channel; otherwise one value per position.
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        source_num_heads: int,
        target_num_heads: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_type: str = "zero", # 'zero', 'ones', 'uniform', 'normal'
        anneal_steps: int = 1360,
        structure_type: str = "channelwise",  # 'channelwise', or 'layerwise'
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: per-head feature size of the source tensors
            target_dim: per-head feature size of the target tensors
            source_num_heads: number of source heads
            target_num_heads: number of target heads
            hidden_dim: width of all hidden layers
            num_layers: number of linear layers in each projection MLP
            dropout: dropout probability used in the MLPs and the weight generator
            activation: one of 'gelu', 'relu', 'silu' (case-insensitive)
            use_layer_norm: insert LayerNorm after each hidden linear layer
            init_type: initialisation of the weight-head matrices; any value other
                than 'zero', 'ones', 'uniform', 'normal' keeps nn.Linear's default
            anneal_steps: number of steps over which the gate temperature decays
            structure_type: 'channelwise' for one weight/gate per head and channel;
                any other value gives one weight/gate per position
            dtype: parameter dtype
        Raises:
            ValueError: if ``activation`` is not supported
        """
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.source_num_heads = source_num_heads
        self.target_num_heads = target_num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.structure_type = structure_type

        # Activation function
        if activation.lower() == "gelu":
            self.activation = nn.GELU()
        elif activation.lower() == "relu":
            self.activation = nn.ReLU()
        elif activation.lower() == "silu":
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        # forward() feeds every module below with head-flattened tensors, so they
        # must be sized with the flattened width, not the per-head width.
        source_flat_dim = source_dim * source_num_heads
        target_flat_dim = target_dim * target_num_heads

        # Build separate MLP layers for key and value projection
        self.key_projection = self._build_mlp(source_flat_dim, hidden_dim, target_flat_dim, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_flat_dim, hidden_dim, target_flat_dim, num_layers, use_layer_norm, dropout, dtype)

        # Reduce concat([projected, target]) back to the flattened target width
        self.key_combiner = nn.Linear(target_flat_dim * 2, target_flat_dim, dtype=dtype)
        self.value_combiner = nn.Linear(target_flat_dim * 2, target_flat_dim, dtype=dtype)

        # Per-token dynamic weight generators (shared hidden layers)
        self.weight_hidden = nn.Sequential(
            nn.Linear(target_flat_dim, hidden_dim, dtype=dtype),
            self.activation,
            nn.Dropout(dropout)
        )

        if structure_type == 'channelwise':
            # Separate heads for key and value weights; forward() views the
            # output as (H_t, D_t), so it needs H_t * D_t features
            self.key_weight_head = nn.Linear(hidden_dim, target_flat_dim, dtype=dtype)
            self.value_weight_head = nn.Linear(hidden_dim, target_flat_dim, dtype=dtype)
            # Per-token gate generator
            self.gate_generator = nn.Sequential(
                nn.Linear(target_flat_dim, hidden_dim, dtype=dtype),
                self.activation,
                nn.Linear(hidden_dim, target_flat_dim, dtype=dtype)
            )
        else:
            # One scalar weight per position
            self.key_weight_head = nn.Linear(hidden_dim, 1, dtype=dtype)
            self.value_weight_head = nn.Linear(hidden_dim, 1, dtype=dtype)
            # Per-token gate generator
            self.gate_generator = nn.Sequential(
                nn.Linear(target_flat_dim, hidden_dim, dtype=dtype),
                self.activation,
                nn.Linear(hidden_dim, 1, dtype=dtype)
            )

        # Only the weight matrices are re-initialised; biases keep the nn.Linear default.
        if init_type == "zero":
            nn.init.zeros_(self.key_weight_head.weight)
            nn.init.zeros_(self.value_weight_head.weight)
        elif init_type == "ones":
            nn.init.ones_(self.key_weight_head.weight)
            nn.init.ones_(self.value_weight_head.weight)
        elif init_type == "uniform":
            nn.init.uniform_(self.key_weight_head.weight, a=0.0, b=1.0)
            nn.init.uniform_(self.value_weight_head.weight, a=0.0, b=1.0)
        elif init_type == "normal":
            nn.init.normal_(self.key_weight_head.weight, mean=0.0, std=1.0)
            nn.init.normal_(self.value_weight_head.weight, mean=0.0, std=1.0)

        # Gumbel-sigmoid temperature; a buffer so it is saved with the state dict.
        self.register_buffer("gate_temperature", torch.tensor(1.0))  # initial temperature
        self.initial_temperature = 1.0
        self.final_temperature = 0.01
        self.anneal_steps = anneal_steps
        # Fixed divisor that sharpens the blend-weight sigmoid in forward().
        self.scalar_temperature = 0.005

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """
        Build a single MLP projection module
        Args:
            source_dim: input feature size
            hidden_dim: hidden feature size
            target_dim: output feature size
            num_layers: number of linear layers; values <= 1 give a single linear map
            use_layer_norm: insert LayerNorm after each hidden linear layer
            dropout: dropout probability after each hidden activation
            dtype: parameter dtype
        Returns:
            The assembled nn.Sequential
        """
        layers = []

        # Input projection
        layers.append(nn.Linear(source_dim, hidden_dim, dtype=dtype))
        if use_layer_norm:
            layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
        # Each MLP gets its own activation instance
        layers.append(copy.deepcopy(self.activation))
        layers.append(nn.Dropout(dropout))

        # Hidden layers
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))

        # Output projection
        if num_layers > 1:
            layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
        else:
            # Single layer case: everything built above is replaced by one linear map
            layers = [nn.Linear(source_dim, target_dim, dtype=dtype)]

        return nn.Sequential(*layers)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule
        Args:
            step: current training step
        """
        if self.anneal_steps <= 0:
            # No annealing window: go straight to the final temperature
            # instead of dividing by zero.
            ratio = 1.0
        else:
            ratio = min(step / self.anneal_steps, 1.0)
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, eps=1e-20, device=None):
        """
        Draw Gumbel(0, 1) noise via the inverse-CDF transform of uniform samples
        Args:
            shape: shape of the returned tensor
            eps: small constant keeping both logarithms finite
            device: device to sample on (default: torch's current default device)
        Returns:
            Tensor of the given shape in torch's default dtype
        """
        U = torch.rand(shape, device=device)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_sigmoid_sample(self, gate_logit):
        """
        Relaxed Bernoulli sample: sigmoid of the noisy logit divided by the gate temperature
        Args:
            gate_logit: tensor of gate logits
        Returns:
            Tensor of soft gates in (0, 1) with the same shape as ``gate_logit``
        """
        # Sample directly on the logits' device to avoid a host-to-device copy per call.
        g0 = self.sample_gumbel(gate_logit.size(), device=gate_logit.device)
        g1 = self.sample_gumbel(gate_logit.size(), device=gate_logit.device)
        return torch.sigmoid((gate_logit + g1 - g0) / self.gate_temperature)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Fuse projected source key-value tensors with the target tensors and blend the result into the target
        Args:
            source_kv: Tuple of (key, value) tensors, each (B, H_s, N, D_s)
            target_kv: Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Layout is (B, H, N, D): H is num_heads, N is seq_len, D is head_dim
        batch_size, source_num_heads, seq_len, source_head_dim = source_key.shape
        _, target_num_heads, _, target_head_dim = target_key.shape
        source_flat_dim = source_num_heads * source_head_dim
        target_flat_dim = target_num_heads * target_head_dim

        def split_heads(flat: Tensor) -> Tensor:
            # (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
            return flat.view(batch_size, seq_len, target_num_heads, target_head_dim).transpose(1, 2)

        def per_position(flat: Tensor) -> Tensor:
            # (B, N, 1) -> (B, 1, N, 1), broadcastable against (B, H_t, N, D_t)
            return flat.view(batch_size, seq_len, 1, 1).transpose(1, 2)

        channelwise = self.structure_type == 'channelwise'
        to_target_layout = split_heads if channelwise else per_position

        # Merge heads for projection: (B, H_s, N, D_s) -> (B, N, H_s * D_s)
        source_key_flat = source_key.transpose(1, 2).contiguous().view(batch_size, seq_len, source_flat_dim)
        source_value_flat = source_value.transpose(1, 2).contiguous().view(batch_size, seq_len, source_flat_dim)

        # Project source tensors from (H_s * D_s) to (H_t * D_t)
        projected_key_flat = self.key_projection(source_key_flat)  # (B, N, H_t * D_t)
        projected_value_flat = self.value_projection(source_value_flat)  # (B, N, H_t * D_t)

        # Merge target heads the same way: (B, H_t, N, D_t) -> (B, N, H_t * D_t)
        target_key_flat = target_key.transpose(1, 2).contiguous().view(batch_size, seq_len, target_flat_dim)
        target_value_flat = target_value.transpose(1, 2).contiguous().view(batch_size, seq_len, target_flat_dim)

        # Combine with target tensors
        combined_key = torch.cat([projected_key_flat, target_key_flat], dim=-1)  # (B, N, 2 * H_t * D_t)
        combined_value = torch.cat([projected_value_flat, target_value_flat], dim=-1)  # (B, N, 2 * H_t * D_t)

        fused_key = split_heads(self.key_combiner(combined_key))  # (B, H_t, N, D_t)
        fused_value = split_heads(self.value_combiner(combined_value))  # (B, H_t, N, D_t)

        # Weights and gates currently depend on the target key only
        weight_hidden = self.weight_hidden(target_key_flat)  # (B, N, hidden_dim)

        # Per-token gates: (B, H_t, N, D_t) when channelwise, else (B, 1, N, 1)
        gate_logit = to_target_layout(self.gate_generator(target_key_flat))

        if self.training:
            # Differentiable soft gate via Gumbel-sigmoid
            gate = self.gumbel_sigmoid_sample(gate_logit)
        else:
            # Deterministic hard gate; .float() makes it float32
            gate = (gate_logit > 0).float()

        key_weight = to_target_layout(self.key_weight_head(weight_hidden))
        value_weight = to_target_layout(self.value_weight_head(weight_hidden))

        # Dividing by the small scalar temperature pushes the sigmoid towards 0/1
        normalized_key_weight = torch.sigmoid(key_weight / self.scalar_temperature)
        normalized_value_weight = torch.sigmoid(value_weight / self.scalar_temperature)

        # Convex-style blend; the arithmetic allocates new tensors, so the inputs are not modified
        output_key_result = (1 - normalized_key_weight) * target_key + gate * normalized_key_weight * fused_key
        output_value_result = (1 - normalized_value_weight) * target_value + gate * normalized_value_weight * fused_value
        return (output_key_result, output_value_result)
    

@register_model
@capture_init_args
class FFNProjector(Projector):
    """
    MLP projector: source keys/values are pushed through a stack of SwigLUBlock stages
    followed by a linear head, and the result is added onto the target keys/values,
    scaled by a learnable scalar (one scalar for keys, one for values).
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        use_layer_norm: bool = True,
        init_weight: float = 0.1,
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: input width of the first SwigLUBlock (last dim of the source tensors)
            target_dim: output width of the final linear layer
            hidden_dim: width used between stages
            num_layers: number of SwigLUBlock stages (at least one stage is always built)
            dropout: probability passed to the nn.Dropout that closes each stage
            use_layer_norm: insert nn.LayerNorm(hidden_dim) after each SwigLUBlock
            init_weight: initial value of the two learnable mixing scalars
            dtype: dtype of all parameters created here
        """
        super().__init__()

        self.source_dim, self.target_dim = source_dim, target_dim
        self.hidden_dim, self.num_layers = hidden_dim, num_layers

        # Keys and values get independent stacks; the key stack is created first.
        mlp_config = (source_dim, hidden_dim, target_dim, num_layers, use_layer_norm, dropout, dtype)
        self.key_projection = self._build_mlp(*mlp_config)
        self.value_projection = self._build_mlp(*mlp_config)

        # Learnable scalars that scale the projected source before it is added to the target
        self.key_weight = nn.Parameter(torch.tensor(init_weight, dtype=dtype))
        self.value_weight = nn.Parameter(torch.tensor(init_weight, dtype=dtype))

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """
        Assemble one projection stack.

        Layout: one source_dim -> hidden_dim stage, then (num_layers - 1) hidden_dim -> hidden_dim
        stages (none when num_layers <= 1), then nn.Linear(hidden_dim, target_dim).
        A stage is SwigLUBlock [+ LayerNorm] + Dropout.
        """
        stages = []

        def add_stage(in_features: int) -> None:
            stages.append(SwigLUBlock(in_features, hidden_dim, dtype=dtype))
            if use_layer_norm:
                stages.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            stages.append(nn.Dropout(dropout))

        # Entry stage maps the source width onto the hidden width
        add_stage(source_dim)

        # Remaining stages stay at the hidden width; range() is empty for num_layers <= 1
        for _ in range(num_layers - 1):
            add_stage(hidden_dim)

        # Linear head down/up to the target width
        stages.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))

        return nn.Sequential(*stages)

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Map the source key/value tensors to the target width and blend them additively into the target
        Args:
            source_kv: Tuple of (key, value) tensors, each (..., D_s) where ... are arbitrary leading dimensions
            target_kv: Tuple of (key, value) tensors, each (..., D_t) where ... are arbitrary leading dimensions
        Returns:
            Tuple of (key, value) tensors, each (..., D_t) with same leading dimensions as input
        """
        src_key, src_value = source_kv
        tgt_key, tgt_value = target_kv

        # target + learnable_scalar * MLP(source); the key stack runs before the value stack
        blended_key = tgt_key + self.key_weight * self.key_projection(src_key)
        blended_value = tgt_value + self.value_weight * self.value_projection(src_value)

        return (blended_key, blended_value)
    

@register_model
@capture_init_args
class OldTransformerProjector(Projector):
    """
    Legacy transformer-decoder projector (currently disabled).

    The constructor raises NotImplementedError immediately after ``super().__init__()``,
    so the class cannot be instantiated through its normal constructor. The rest of the
    constructor and ``forward`` are kept as a reference for the intended design: target
    keys/values are used as decoder queries that attend to the source keys/values, and the
    decoded result is added back onto the target tensors.
    """
    
    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_size: int = 512,
        num_attention_heads: int = 8,
        num_layers: int = 2,
        feedforward_dim: Optional[int] = None,
        dropout: float = 0.1,
        activation: str = "gelu",
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: last-dim size of the source key/value tensors
            target_dim: last-dim size of the target key/value tensors
            hidden_size: d_model of the transformer decoders
            num_attention_heads: nhead of each decoder layer
            num_layers: number of layers in each decoder
            feedforward_dim: decoder feed-forward width; defaults to 4 * hidden_size
            dropout: dropout probability for the positional encoding and decoder layers
            activation: activation name passed to nn.TransformerDecoderLayer
            dtype: dtype of all parameters created here
        Raises:
            NotImplementedError: always (the class is disabled)
        """
        super().__init__()

        # The message names this class explicitly so the error is not mistaken for a
        # problem with the separate TransformerProjector class defined in this file.
        raise NotImplementedError("OldTransformerProjector is not fully tested yet")
        
        # NOTE: everything below is unreachable while the guard above is in place.
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        if feedforward_dim is None:
            feedforward_dim = hidden_size * 4
        
        # Input projections (separate for key and value)
        self.source_key_projection = nn.Linear(source_dim, hidden_size, dtype=dtype)
        self.source_value_projection = nn.Linear(source_dim, hidden_size, dtype=dtype)
        self.target_key_projection = nn.Linear(target_dim, hidden_size, dtype=dtype)
        self.target_value_projection = nn.Linear(target_dim, hidden_size, dtype=dtype)
        
        # Positional encoding
        self.positional_encoding = PositionalEncoding(hidden_size, dropout, dtype=dtype)
        
        # Create transformer decoder layers (separate for key and value)
        key_decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            dtype=dtype
        )
        
        value_decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            dtype=dtype
        )
        
        # Build transformer decoders
        self.key_transformer_decoder = nn.TransformerDecoder(
            key_decoder_layer, num_layers=num_layers
        )
        self.value_transformer_decoder = nn.TransformerDecoder(
            value_decoder_layer, num_layers=num_layers
        )
        
        # Output projections (separate for key and value)
        self.key_output_projection = nn.Linear(hidden_size, target_dim, dtype=dtype)
        self.value_output_projection = nn.Linear(hidden_size, target_dim, dtype=dtype)
        
        # Layer norm for final output (separate for key and value)
        self.key_final_layer_norm = nn.LayerNorm(target_dim, dtype=dtype)
        self.value_final_layer_norm = nn.LayerNorm(target_dim, dtype=dtype)
        
        # Extra learned residual path from the source, only built when the widths match
        self.use_residual = (source_dim == target_dim)
        if self.use_residual:
            self.key_residual_projection = nn.Linear(source_dim, target_dim, dtype=dtype)
            self.value_residual_projection = nn.Linear(source_dim, target_dim, dtype=dtype)
    
    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project and combine the source key-value tensors to the target key-value tensors using transformer layers
        Args:
            source_kv: Tuple of (key, value) tensors; must be 4-D with last dim D_s
            target_kv: Tuple of (key, value) tensors; must be 4-D with last dim D_t
        Returns:
            Tuple of (key, value) tensors with the same shape as the target tensors
        Note:
            The shape unpacking below names the axes (B, N, H, D).
            NOTE(review): confirm this layout against callers before re-enabling the class.
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv
        
        B, N, H, D_s = source_key.shape
        _, _, _, D_t = target_key.shape
        
        # Project inputs to hidden size (separately for key and value)
        source_key_proj = self.source_key_projection(source_key)  # (B, N, H, hidden_size)
        source_value_proj = self.source_value_projection(source_value)  # (B, N, H, hidden_size)
        target_key_proj = self.target_key_projection(target_key)  # (B, N, H, hidden_size)
        target_value_proj = self.target_value_projection(target_value)  # (B, N, H, hidden_size)
        
        # Flatten to sequence dimension for transformer
        # Treat each (N, H) as a sequence element
        source_key_seq = source_key_proj.view(B, N * H, self.hidden_size)  # (B, N*H, hidden_size)
        source_value_seq = source_value_proj.view(B, N * H, self.hidden_size)  # (B, N*H, hidden_size)
        target_key_seq = target_key_proj.view(B, N * H, self.hidden_size)  # (B, N*H, hidden_size)
        target_value_seq = target_value_proj.view(B, N * H, self.hidden_size)  # (B, N*H, hidden_size)
        
        # Add positional encoding
        source_key_seq = self.positional_encoding(source_key_seq)
        source_value_seq = self.positional_encoding(source_value_seq)
        target_key_seq = self.positional_encoding(target_key_seq)
        target_value_seq = self.positional_encoding(target_value_seq)
        
        # Use transformer decoder: target attends to source (separately for key and value)
        # Memory (key/value) is source, target is query
        transformed_key = self.key_transformer_decoder(
            tgt=target_key_seq,    # Query sequences
            memory=source_key_seq  # Key/Value sequences
        )  # (B, N*H, hidden_size)
        
        transformed_value = self.value_transformer_decoder(
            tgt=target_value_seq,    # Query sequences
            memory=source_value_seq  # Key/Value sequences
        )  # (B, N*H, hidden_size)
        
        # Reshape back to original structure
        transformed_key = transformed_key.view(B, N, H, self.hidden_size)  # (B, N, H, hidden_size)
        transformed_value = transformed_value.view(B, N, H, self.hidden_size)  # (B, N, H, hidden_size)
        
        # Project to target dimension
        output_key = self.key_output_projection(transformed_key)  # (B, N, H, D_t)
        output_value = self.value_output_projection(transformed_value)  # (B, N, H, D_t)
        
        output_key = self.key_final_layer_norm(output_key)
        output_value = self.value_final_layer_norm(output_value)
        
        # Always add the target back; additionally add a learned projection of the
        # source when the residual path was built (source_dim == target_dim)
        if self.use_residual:
            key_residual = self.key_residual_projection(source_key)
            value_residual = self.value_residual_projection(source_value)
            output_key = output_key + key_residual + target_key
            output_value = output_value + value_residual + target_value
        else:
            output_key = output_key + target_key
            output_value = output_value + target_value
        
        return (output_key, output_value)


@register_model
@capture_init_args
class TransformerProjector(Projector):
    """
    Transformer-decoder projector.

    Heads are merged into the feature axis, so every token is one vector of size H * D.
    Target keys/values are the decoder queries (``tgt``) and source keys/values are the
    decoder ``memory``; the decoded, layer-normed result is reshaped back to the target
    layout and added onto the target tensors. Keys and values use fully separate
    projection / decoder / norm stacks.

    NOTE(review): the decoder layers are built with ``batch_first=True`` while ``forward``
    feeds them tensors laid out (N, B, hidden). With ``batch_first=True`` PyTorch reads
    dim 0 as batch and dim 1 as sequence, so attention runs over the B axis. Whether that
    is intended depends on the layout ``PositionalEncoding`` expects, which is not visible
    here; behavior is left unchanged.
    """
    
    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_attention_heads: int = 8,
        num_layers: int = 2,
        feedforward_dim: Optional[int] = None,
        dropout: float = 0.1,
        activation: str = "gelu",
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: in_features of the source input projections; forward() feeds them
                head-merged tensors, so this must equal H_s * D_s of the source tensors
            target_dim: in_features of the target input projections and out_features of the
                output projections; must equal H_t * D_t of the target tensors
            hidden_dim: d_model of the transformer decoders (stored as ``self.hidden_size``)
            num_attention_heads: nhead of each decoder layer
            num_layers: number of layers in each decoder
            feedforward_dim: decoder feed-forward width; defaults to 4 * hidden_dim
            dropout: dropout probability for the positional encoding and decoder layers
            activation: activation name passed to nn.TransformerDecoderLayer
            dtype: dtype of all parameters created here
        """
        super().__init__()
        
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_size = hidden_dim
        self.num_layers = num_layers
        
        if feedforward_dim is None:
            feedforward_dim = hidden_dim * 4
        
        # Input projections (separate for key and value)
        self.source_key_projection = nn.Linear(source_dim, hidden_dim, dtype=dtype)
        self.source_value_projection = nn.Linear(source_dim, hidden_dim, dtype=dtype)
        self.target_key_projection = nn.Linear(target_dim, hidden_dim, dtype=dtype)
        self.target_value_projection = nn.Linear(target_dim, hidden_dim, dtype=dtype)

        # Zero the input-projection weights (biases keep their default init), so at
        # initialization the projected inputs do not depend on the input values.
        nn.init.zeros_(self.source_key_projection.weight)
        nn.init.zeros_(self.source_value_projection.weight)
        nn.init.zeros_(self.target_key_projection.weight)
        nn.init.zeros_(self.target_value_projection.weight)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(hidden_dim, dropout, max_len=32768, dtype=dtype)
        
        # Create transformer decoder layers (separate for key and value)
        key_decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            dtype=dtype
        )
        
        value_decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            dtype=dtype
        )
        
        # Build transformer decoders
        self.key_transformer_decoder = nn.TransformerDecoder(
            key_decoder_layer, num_layers=num_layers
        )
        self.value_transformer_decoder = nn.TransformerDecoder(
            value_decoder_layer, num_layers=num_layers
        )
        
        # Output projections (separate for key and value)
        self.key_output_projection = nn.Linear(hidden_dim, target_dim, dtype=dtype)
        self.value_output_projection = nn.Linear(hidden_dim, target_dim, dtype=dtype)
        
        # Layer norm for final output (separate for key and value)
        self.key_final_layer_norm = nn.LayerNorm(target_dim, dtype=dtype)
        self.value_final_layer_norm = nn.LayerNorm(target_dim, dtype=dtype)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule:
        ``initial * (final / initial) ** min(step / anneal_steps, 1)``, written in place
        into ``self.gate_temperature``.

        NOTE(review): ``anneal_steps``, ``initial_temperature``, ``final_temperature`` and
        ``gate_temperature`` are not set by this class's ``__init__``; unless they are
        attached elsewhere, calling this raises AttributeError.
        Args:
            step: current training step
        """
        ratio = min(step / self.anneal_steps, 1.0)
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, device, dtype, eps=1e-20):
        """
        Draw Gumbel noise as ``-log(-log(U + eps) + eps)`` with ``U`` from ``torch.rand``.
        Args:
            shape: shape of the sample
            device: device of the sample
            dtype: dtype of the sample
            eps: small constant added inside both logs
        """
        U = torch.rand(shape, device=device, dtype=dtype)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_sigmoid_sample(self):
        """
        Return ``sigmoid((gate_logit + g1 - g0) / gate_temperature)`` where g0 and g1 are
        two independent Gumbel samples shaped like ``self.gate_logit``.

        NOTE(review): ``gate_logit`` and ``gate_temperature`` are not set by this class's
        ``__init__``; unless they are attached elsewhere, calling this raises AttributeError.
        """
        dev = self.gate_logit.device
        dt = self.gate_logit.dtype
        g0 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        g1 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        y = torch.sigmoid((self.gate_logit + g1 - g0) / self.gate_temperature)
        return y

    @staticmethod
    def _merge_heads(x: Tensor) -> Tensor:
        """Reshape a 4-D (B, H, N, D) tensor to (N, B, H * D) using the tensor's own shape."""
        b, h, n, d = x.shape
        return x.permute(2, 0, 1, 3).contiguous().view(n, b, h * d)

    @staticmethod
    def _split_heads(x: Tensor, like: Tensor) -> Tensor:
        """Reshape (N, B, H * D) back to (B, H, N, D), taking B, H, N, D from ``like``."""
        b, h, n, d = like.shape
        return x.view(n, b, h, d).permute(1, 2, 0, 3).contiguous()
    
    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project and combine the source key-value tensors to the target key-value tensors using transformer layers
        Args:
            source_kv: Tuple of (key, value) tensors, each 4-D (B, H_s, N_s, D_s) with H_s * D_s == source_dim
            target_kv: Tuple of (key, value) tensors, each 4-D (B, H_t, N_t, D_t) with H_t * D_t == target_dim
        Returns:
            Tuple of (key, value) tensors with the same shape as the corresponding target tensor
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Merge heads into the feature axis. Each tensor is reshaped from its OWN shape:
        # reusing the source's head count / head dim / length for the target tensors
        # breaks as soon as source and target differ in any of them.
        source_key_reshaped = self._merge_heads(source_key)      # (N_s, B, H_s*D_s)
        source_value_reshaped = self._merge_heads(source_value)  # (N_s, B, H_s*D_s)
        target_key_reshaped = self._merge_heads(target_key)      # (N_t, B, H_t*D_t)
        target_value_reshaped = self._merge_heads(target_value)  # (N_t, B, H_t*D_t)

        # Project inputs to hidden size (separately for key and value)
        source_key_proj = self.source_key_projection(source_key_reshaped)  # (N_s, B, hidden_size)
        source_value_proj = self.source_value_projection(source_value_reshaped)  # (N_s, B, hidden_size)
        target_key_proj = self.target_key_projection(target_key_reshaped)  # (N_t, B, hidden_size)
        target_value_proj = self.target_value_projection(target_value_reshaped)  # (N_t, B, hidden_size)
        
        # Add positional encoding
        source_key_seq = self.positional_encoding(source_key_proj)
        source_value_seq = self.positional_encoding(source_value_proj)
        target_key_seq = self.positional_encoding(target_key_proj)
        target_value_seq = self.positional_encoding(target_value_proj)
        
        # Use transformer decoder: target attends to source (separately for key and value)
        # Memory (key/value) is source, target is query
        transformed_key = self.key_transformer_decoder(
            tgt=target_key_seq,    # Query sequences
            memory=source_key_seq  # Key/Value sequences
        )  # (N_t, B, hidden_size)
        
        transformed_value = self.value_transformer_decoder(
            tgt=target_value_seq,    # Query sequences
            memory=source_value_seq  # Key/Value sequences
        )  # (N_t, B, hidden_size)

        # Project to target dimension
        output_key = self.key_output_projection(transformed_key)  # (N_t, B, H_t * D_t)
        output_value = self.value_output_projection(transformed_value)  # (N_t, B, H_t * D_t)

        output_key = self.key_final_layer_norm(output_key)
        output_value = self.value_final_layer_norm(output_value)

        # Back to the target's own (B, H_t, N_t, D_t) layout
        output_key = self._split_heads(output_key, target_key)
        output_value = self._split_heads(output_value, target_value)

        # Residual connection onto the target tensors
        output_key = output_key + target_key
        output_value = output_value + target_value
        
        return (output_key, output_value)


@register_model
@capture_init_args
class GatedTransformerProjector(Projector):
    """
    Transformer-decoder projector.

    Heads are merged into the feature axis, so every token is one vector of size H * D.
    Target keys/values are the decoder queries (``tgt``) and source keys/values are the
    decoder ``memory``; the decoded, layer-normed result is reshaped back to the target
    layout and added onto the target tensors. Keys and values use fully separate
    projection / decoder / norm stacks.

    NOTE(review): despite the name, ``forward`` applies no gate; the temperature / gumbel
    helpers below are never called from within this class.

    NOTE(review): a second class with this exact name is defined further down in this
    file. Being defined later, it rebinds the module-level name
    ``GatedTransformerProjector``; how ``register_model`` treats the duplicate
    registration is not visible here.

    NOTE(review): the decoder layers are built with ``batch_first=True`` while ``forward``
    feeds them tensors laid out (N, B, hidden). With ``batch_first=True`` PyTorch reads
    dim 0 as batch and dim 1 as sequence, so attention runs over the B axis. Whether that
    is intended depends on the layout ``PositionalEncoding`` expects, which is not visible
    here; behavior is left unchanged.
    """
    
    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_attention_heads: int = 8,
        num_layers: int = 2,
        feedforward_dim: Optional[int] = None,
        dropout: float = 0.1,
        activation: str = "gelu",
        dtype: torch.dtype = torch.float32
    ):
        """
        Args:
            source_dim: in_features of the source input projections; forward() feeds them
                head-merged tensors, so this must equal H_s * D_s of the source tensors
            target_dim: in_features of the target input projections and out_features of the
                output projections; must equal H_t * D_t of the target tensors
            hidden_dim: d_model of the transformer decoders (stored as ``self.hidden_size``)
            num_attention_heads: nhead of each decoder layer
            num_layers: number of layers in each decoder
            feedforward_dim: decoder feed-forward width; defaults to 4 * hidden_dim
            dropout: dropout probability for the positional encoding and decoder layers
            activation: activation name passed to nn.TransformerDecoderLayer
            dtype: dtype of all parameters created here
        """
        super().__init__()
        
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_size = hidden_dim
        self.num_layers = num_layers
        
        if feedforward_dim is None:
            feedforward_dim = hidden_dim * 4
        
        # Input projections (separate for key and value)
        self.source_key_projection = nn.Linear(source_dim, hidden_dim, dtype=dtype)
        self.source_value_projection = nn.Linear(source_dim, hidden_dim, dtype=dtype)
        self.target_key_projection = nn.Linear(target_dim, hidden_dim, dtype=dtype)
        self.target_value_projection = nn.Linear(target_dim, hidden_dim, dtype=dtype)

        # Zero the input-projection weights (biases keep their default init), so at
        # initialization the projected inputs do not depend on the input values.
        nn.init.zeros_(self.source_key_projection.weight)
        nn.init.zeros_(self.source_value_projection.weight)
        nn.init.zeros_(self.target_key_projection.weight)
        nn.init.zeros_(self.target_value_projection.weight)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(hidden_dim, dropout, max_len=32768, dtype=dtype)
        
        # Create transformer decoder layers (separate for key and value)
        key_decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            dtype=dtype
        )
        
        value_decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            dtype=dtype
        )
        
        # Build transformer decoders
        self.key_transformer_decoder = nn.TransformerDecoder(
            key_decoder_layer, num_layers=num_layers
        )
        self.value_transformer_decoder = nn.TransformerDecoder(
            value_decoder_layer, num_layers=num_layers
        )
        
        # Output projections (separate for key and value)
        self.key_output_projection = nn.Linear(hidden_dim, target_dim, dtype=dtype)
        self.value_output_projection = nn.Linear(hidden_dim, target_dim, dtype=dtype)
        
        # Layer norm for final output (separate for key and value)
        self.key_final_layer_norm = nn.LayerNorm(target_dim, dtype=dtype)
        self.value_final_layer_norm = nn.LayerNorm(target_dim, dtype=dtype)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule:
        ``initial * (final / initial) ** min(step / anneal_steps, 1)``, written in place
        into ``self.gate_temperature``.

        NOTE(review): ``anneal_steps``, ``initial_temperature``, ``final_temperature`` and
        ``gate_temperature`` are not set by this class's ``__init__``; unless they are
        attached elsewhere, calling this raises AttributeError.
        Args:
            step: current training step
        """
        ratio = min(step / self.anneal_steps, 1.0)
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, device, dtype, eps=1e-20):
        """
        Draw Gumbel noise as ``-log(-log(U + eps) + eps)`` with ``U`` from ``torch.rand``.
        Args:
            shape: shape of the sample
            device: device of the sample
            dtype: dtype of the sample
            eps: small constant added inside both logs
        """
        U = torch.rand(shape, device=device, dtype=dtype)
        return -torch.log(-torch.log(U + eps) + eps)

    def gumbel_sigmoid_sample(self):
        """
        Return ``sigmoid((gate_logit + g1 - g0) / gate_temperature)`` where g0 and g1 are
        two independent Gumbel samples shaped like ``self.gate_logit``.

        NOTE(review): ``gate_logit`` and ``gate_temperature`` are not set by this class's
        ``__init__``; unless they are attached elsewhere, calling this raises AttributeError.
        """
        dev = self.gate_logit.device
        dt = self.gate_logit.dtype
        g0 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        g1 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        y = torch.sigmoid((self.gate_logit + g1 - g0) / self.gate_temperature)
        return y

    @staticmethod
    def _merge_heads(x: Tensor) -> Tensor:
        """Reshape a 4-D (B, H, N, D) tensor to (N, B, H * D) using the tensor's own shape."""
        b, h, n, d = x.shape
        return x.permute(2, 0, 1, 3).contiguous().view(n, b, h * d)

    @staticmethod
    def _split_heads(x: Tensor, like: Tensor) -> Tensor:
        """Reshape (N, B, H * D) back to (B, H, N, D), taking B, H, N, D from ``like``."""
        b, h, n, d = like.shape
        return x.view(n, b, h, d).permute(1, 2, 0, 3).contiguous()
    
    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project and combine the source key-value tensors to the target key-value tensors using transformer layers
        Args:
            source_kv: Tuple of (key, value) tensors, each 4-D (B, H_s, N_s, D_s) with H_s * D_s == source_dim
            target_kv: Tuple of (key, value) tensors, each 4-D (B, H_t, N_t, D_t) with H_t * D_t == target_dim
        Returns:
            Tuple of (key, value) tensors with the same shape as the corresponding target tensor
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Merge heads into the feature axis. Each tensor is reshaped from its OWN shape:
        # reusing the source's head count / head dim / length for the target tensors
        # breaks as soon as source and target differ in any of them.
        source_key_reshaped = self._merge_heads(source_key)      # (N_s, B, H_s*D_s)
        source_value_reshaped = self._merge_heads(source_value)  # (N_s, B, H_s*D_s)
        target_key_reshaped = self._merge_heads(target_key)      # (N_t, B, H_t*D_t)
        target_value_reshaped = self._merge_heads(target_value)  # (N_t, B, H_t*D_t)

        # Project inputs to hidden size (separately for key and value)
        source_key_proj = self.source_key_projection(source_key_reshaped)  # (N_s, B, hidden_size)
        source_value_proj = self.source_value_projection(source_value_reshaped)  # (N_s, B, hidden_size)
        target_key_proj = self.target_key_projection(target_key_reshaped)  # (N_t, B, hidden_size)
        target_value_proj = self.target_value_projection(target_value_reshaped)  # (N_t, B, hidden_size)
        
        # Add positional encoding
        source_key_seq = self.positional_encoding(source_key_proj)
        source_value_seq = self.positional_encoding(source_value_proj)
        target_key_seq = self.positional_encoding(target_key_proj)
        target_value_seq = self.positional_encoding(target_value_proj)
        
        # Use transformer decoder: target attends to source (separately for key and value)
        # Memory (key/value) is source, target is query
        transformed_key = self.key_transformer_decoder(
            tgt=target_key_seq,    # Query sequences
            memory=source_key_seq  # Key/Value sequences
        )  # (N_t, B, hidden_size)
        
        transformed_value = self.value_transformer_decoder(
            tgt=target_value_seq,    # Query sequences
            memory=source_value_seq  # Key/Value sequences
        )  # (N_t, B, hidden_size)

        # Project to target dimension
        output_key = self.key_output_projection(transformed_key)  # (N_t, B, H_t * D_t)
        output_value = self.value_output_projection(transformed_value)  # (N_t, B, H_t * D_t)

        output_key = self.key_final_layer_norm(output_key)
        output_value = self.value_final_layer_norm(output_value)

        # Back to the target's own (B, H_t, N_t, D_t) layout
        output_key = self._split_heads(output_key, target_key)
        output_value = self._split_heads(output_value, target_value)

        # Residual connection onto the target tensors
        output_key = output_key + target_key
        output_value = output_value + target_value
        
        return (output_key, output_value)


@register_model
@capture_init_args
class GatedTransformerProjector(Projector):
    """
    Transformer-based projector using PyTorch transformer layers for sophisticated projection
    """
    
    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        hidden_dim: int = 512,
        num_attention_heads: int = 8,
        num_layers: int = 2,
        feedforward_dim: Optional[int] = None,
        dropout: float = 0.1,
        activation: str = "gelu",
        dtype: torch.dtype = torch.float32
    ):
        super().__init__()
        
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.hidden_size = hidden_dim
        self.num_layers = num_layers
        
        if feedforward_dim is None:
            feedforward_dim = hidden_dim * 4
        
        # Input projections (separate for key and value)
        self.source_key_projection = nn.Linear(source_dim, hidden_dim, dtype=dtype)
        self.source_value_projection = nn.Linear(source_dim, hidden_dim, dtype=dtype)
        self.target_key_projection = nn.Linear(target_dim, hidden_dim, dtype=dtype)
        self.target_value_projection = nn.Linear(target_dim, hidden_dim, dtype=dtype)

        nn.init.zeros_(self.source_key_projection.weight)
        nn.init.zeros_(self.source_value_projection.weight)
        nn.init.zeros_(self.target_key_projection.weight)
        nn.init.zeros_(self.target_value_projection.weight)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(hidden_dim, dropout, max_len=32768, dtype=dtype)
        
        # Create transformer decoder layers (separate for key and value)
        key_decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            dtype=dtype
        )
        
        value_decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_attention_heads,
            dim_feedforward=feedforward_dim,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            dtype=dtype
        )
        
        # Build transformer decoders
        self.key_transformer_decoder = nn.TransformerDecoder(
            key_decoder_layer, num_layers=num_layers
        )
        self.value_transformer_decoder = nn.TransformerDecoder(
            value_decoder_layer, num_layers=num_layers
        )
        
        # Output projections (separate for key and value)
        self.key_output_projection = nn.Linear(hidden_dim, target_dim, dtype=dtype)
        self.value_output_projection = nn.Linear(hidden_dim, target_dim, dtype=dtype)
        
        # Layer norm for final output (separate for key and value)
        self.key_final_layer_norm = nn.LayerNorm(target_dim, dtype=dtype)
        self.value_final_layer_norm = nn.LayerNorm(target_dim, dtype=dtype)
        
        # Optional residual connection
        # self.use_residual = (source_dim == target_dim)
        # if self.use_residual:
        #     self.key_residual_projection = nn.Linear(source_dim, target_dim, dtype=dtype)
        #     self.value_residual_projection = nn.Linear(source_dim, target_dim, dtype=dtype)
    
    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project and combine the source key-value tensors to the target key-value tensors using transformer layers.

        For keys and values separately: the heads of every token are flattened into one
        feature vector, projected to the hidden size, passed through the positional
        encoding and the transformer decoder (target as ``tgt``, source as ``memory``),
        projected to the target width, layer-normalised and finally added to the
        original target tensor as a residual.

        Args:
            source_kv: Tuple of (key, value) tensors, each (B, H_s, N, D_s)
            target_kv: Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        B, H_s, N, D_s = source_key.shape
        # The target may use a different head count / head dim than the source, so its
        # flattened width has to come from its own shape (using D_s here made the
        # reshape fail whenever D_t != D_s).
        _, H_t, _, D_t = target_key.shape

        def to_token_major(x: Tensor, width: int) -> Tensor:
            # (B, H, N, D) -> (N, B, H * D): one flattened feature vector per token
            return x.permute(2, 0, 1, 3).reshape(N, B, width)

        source_key_reshaped = to_token_major(source_key, H_s * D_s)
        source_value_reshaped = to_token_major(source_value, H_s * D_s)
        target_key_reshaped = to_token_major(target_key, H_t * D_t)
        target_value_reshaped = to_token_major(target_value, H_t * D_t)

        # Project inputs to hidden size (separately for key and value)
        source_key_proj = self.source_key_projection(source_key_reshaped)  # (N, B, hidden)
        source_value_proj = self.source_value_projection(source_value_reshaped)  # (N, B, hidden)
        target_key_proj = self.target_key_projection(target_key_reshaped)  # (N, B, hidden)
        target_value_proj = self.target_value_projection(target_value_reshaped)  # (N, B, hidden)

        # NOTE(review): these tensors are laid out (N, B, hidden), i.e. sequence-first.
        # The decoder layers visible in __init__ are built with batch_first=True, which
        # would make the decoder treat dim 0 (N) as the batch and dim 1 (B) as the
        # sequence; likewise, if self.positional_encoding is the PositionalEncoding
        # class of this file, positions are indexed along dim 1. Confirm whether a
        # (B, N, hidden) layout was intended. The layout is left unchanged here so
        # existing checkpoints keep their behaviour.

        # Add positional encoding
        source_key_seq = self.positional_encoding(source_key_proj)
        source_value_seq = self.positional_encoding(source_value_proj)
        target_key_seq = self.positional_encoding(target_key_proj)
        target_value_seq = self.positional_encoding(target_value_proj)

        # Use transformer decoder: target attends to source (separately for key and value)
        # Memory (key/value) is source, target is query
        transformed_key = self.key_transformer_decoder(
            tgt=target_key_seq,    # Query sequences
            memory=source_key_seq  # Key/Value sequences
        )  # (N, B, hidden)

        transformed_value = self.value_transformer_decoder(
            tgt=target_value_seq,    # Query sequences
            memory=source_value_seq  # Key/Value sequences
        )  # (N, B, hidden)

        # Project to target dimension
        output_key = self.key_output_projection(transformed_key)  # (N, B, H_t * D_t)
        output_value = self.value_output_projection(transformed_value)  # (N, B, H_t * D_t)

        output_key = self.key_final_layer_norm(output_key)
        output_value = self.value_final_layer_norm(output_value)

        # Split the heads again and restore the cache layout
        output_key = output_key.view(N, B, H_t, D_t).permute(1, 2, 0, 3).contiguous()  # (B, H_t, N, D_t)
        output_value = output_value.view(N, B, H_t, D_t).permute(1, 2, 0, 3).contiguous()  # (B, H_t, N, D_t)

        # Residual connection with the original target tensors
        output_key = output_key + target_key
        output_value = output_value + target_value

        return (output_key, output_value)


class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding for transformer layers.

    A table ``pe`` of shape (max_len, d_model) is precomputed once and registered as
    a buffer: even feature columns hold sines, odd columns hold cosines. ``forward``
    adds the first ``seq_len`` rows to the input and applies dropout.

    Args:
        d_model: Feature (embedding) dimension of the inputs; may be odd
        dropout: Dropout probability applied after adding the encoding
        max_len: Number of positions stored in the table
        dtype: dtype of the stored ``pe`` buffer
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Build the table in at least float32 and cast once at the end. Half-precision
        # types cannot represent large position indices exactly (bfloat16 is exact only
        # up to 256), so computing directly in `dtype` makes neighbouring positions
        # collapse onto the same encoding.
        compute_dtype = dtype if torch.finfo(dtype).bits >= 32 else torch.float32

        position = torch.arange(max_len, dtype=compute_dtype).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=compute_dtype) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model, dtype=compute_dtype)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.to(dtype))

    def forward(self, x: Tensor) -> Tensor:
        """
        Add the positional encoding to ``x`` and apply dropout.

        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]; positions are taken
               along dim 1 and seq_len must not exceed ``max_len``
        Returns:
            Tensor with the same shape as ``x``
        """
        seq_len = x.size(1)
        # (seq_len, d_model) -> (1, seq_len, d_model) so it broadcasts over the batch
        x = x + self.pe[:seq_len].unsqueeze(0)
        return self.dropout(x)


def create_projector(projector_type: str, **kwargs) -> Projector:
    """
    Factory function to create a projector based on type.

    Args:
        projector_type: String indicating the type of projector; it is resolved
            through ``get_projector_class`` from the unified registry
        **kwargs: Additional arguments to pass to the projector constructor

    Returns:
        An instance of the appropriate projector

    Raises:
        Whatever ``get_projector_class`` raises for a type it cannot resolve
        (the previous code re-raised ``ValueError`` here); the exception
        propagates to the caller unchanged.
    """
    # Prefer using the unified registry getter (handles case-insensitive keys).
    # No try/except is needed: catching an error only to re-raise it adds nothing.
    projector_cls = get_projector_class(projector_type)
    return projector_cls(**kwargs)


@register_model
@capture_init_args
class TokenWiseAdditiveProjector(Projector):
    """
    Additive projector that blends a projected source KV into the target KV.

    1. Source keys/values are mapped per token from (H_s * D_s) to (H_t * D_t)
       features by two separate MLPs.
    2. A single learnable gate logit switches the source contribution on or off:
       a Gumbel-sigmoid sample while training, a hard threshold at 0 in eval.
    3. Mixing weights are predicted from the target keys by a shared hidden layer
       followed by separate key / value heads, giving one scalar per
       (batch, head, token).

    Args:
        source_dim: Head dimension D_s of the source tensors
        target_dim: Head dimension D_t of the target tensors
        source_num_heads: Number of source heads H_s
        target_num_heads: Number of target heads H_t
        hidden_dim: Width of the MLP hidden layers and of the weight network
        num_layers: Number of linear layers per projection MLP (<= 1 gives one linear layer)
        dropout: Dropout probability used in the MLPs and the weight network
        activation: One of "gelu", "relu", "silu" (case-insensitive)
        use_layer_norm: Whether each hidden linear layer is followed by LayerNorm
        init_weight: Accepted by the constructor but not used by this class
        anneal_steps: Number of steps over which the gate temperature is annealed
        initial_temperature: Gate temperature at step 0
        final_temperature: Gate temperature once annealing has finished
        scalar_temperature: Divisor applied to the weight logits before the sigmoid
        dtype: dtype of all parameters
    """

    def __init__(
        self,
        source_dim: int,
        target_dim: int,
        source_num_heads: int,
        target_num_heads: int,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.1,
        activation: str = "gelu",
        use_layer_norm: bool = True,
        init_weight: float = 0.1,
        anneal_steps: int = 1360,
        initial_temperature: float = 1.0,
        final_temperature: float = 0.01,
        scalar_temperature: float = 0.005,
        dtype: torch.dtype = torch.float32
    ):
        super().__init__()

        self.source_dim = source_dim
        self.target_dim = target_dim
        self.source_num_heads = source_num_heads
        self.target_num_heads = target_num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Activation function
        activation_classes = {"gelu": nn.GELU, "relu": nn.ReLU, "silu": nn.SiLU}
        activation_cls = activation_classes.get(activation.lower())
        if activation_cls is None:
            raise ValueError(f"Unsupported activation: {activation}")
        self.activation = activation_cls()

        # Build separate MLP layers for key and value projection
        source_width = source_dim * source_num_heads
        target_width = target_dim * target_num_heads
        self.key_projection = self._build_mlp(source_width, hidden_dim, target_width, num_layers, use_layer_norm, dropout, dtype)
        self.value_projection = self._build_mlp(source_width, hidden_dim, target_width, num_layers, use_layer_norm, dropout, dtype)

        # Single scalar gate logit shared by keys and values
        self.gate_logit = nn.Parameter(torch.tensor(0.0, dtype=dtype))

        # Shared hidden layer for weight generation; it consumes one head vector (D_t)
        self.weight_hidden = nn.Sequential(
            nn.Linear(target_dim, hidden_dim, dtype=dtype),
            self.activation,
            nn.Dropout(dropout)
        )

        # Separate heads for key and value weights
        self.key_weight_head = nn.Linear(hidden_dim, 1, dtype=dtype)
        self.value_weight_head = nn.Linear(hidden_dim, 1, dtype=dtype)

        # Start the gate temperature at the configured initial value, so the module
        # matches the annealing schedule even before update_temperature() is called
        # (this used to be hard-coded to 1.0 regardless of initial_temperature).
        self.register_buffer("gate_temperature", torch.tensor(float(initial_temperature)))
        self.initial_temperature = initial_temperature
        self.final_temperature = final_temperature
        self.anneal_steps = anneal_steps
        self.scalar_temperature = scalar_temperature

    def _build_mlp(self, source_dim: int, hidden_dim: int, target_dim: int, num_layers: int,
                   use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
        """
        Build a single MLP projection module.

        With ``num_layers <= 1`` the result is one linear layer from ``source_dim`` to
        ``target_dim``. Otherwise it is an input block, ``num_layers - 2`` hidden blocks
        (each Linear -> optional LayerNorm -> activation -> Dropout) and a final linear
        output layer.
        """
        # Single layer case: return early instead of building blocks that get discarded
        if num_layers <= 1:
            return nn.Sequential(nn.Linear(source_dim, target_dim, dtype=dtype))

        layers = []
        # One input block followed by (num_layers - 2) hidden blocks
        block_inputs = [source_dim] + [hidden_dim] * (num_layers - 2)
        for in_features in block_inputs:
            layers.append(nn.Linear(in_features, hidden_dim, dtype=dtype))
            if use_layer_norm:
                layers.append(nn.LayerNorm(hidden_dim, dtype=dtype))
            # Each block gets its own copy of the activation module
            layers.append(copy.deepcopy(self.activation))
            layers.append(nn.Dropout(dropout))

        # Output projection
        layers.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))

        return nn.Sequential(*layers)

    def update_temperature(self, step: int):
        """
        Update the temperature using exponential annealing schedule.

        The temperature moves geometrically from ``initial_temperature`` at step 0 to
        ``final_temperature`` at ``anneal_steps`` and stays there afterwards.

        Args:
            step: current training step
        """
        if self.anneal_steps <= 0:
            # No annealing window: go straight to the final temperature
            # (also avoids a ZeroDivisionError for anneal_steps == 0)
            ratio = 1.0
        else:
            ratio = min(step / self.anneal_steps, 1.0)
        temp = self.initial_temperature * (self.final_temperature / self.initial_temperature) ** ratio
        self.gate_temperature.fill_(temp)

    @staticmethod
    def sample_gumbel(shape, device, dtype, eps=1e-20):
        """
        Draw Gumbel(0, 1) noise of the given shape, returned in ``dtype``.

        The noise is computed in at least float32: the default ``eps`` underflows to 0
        in float16, which would allow log(0) and therefore infinite / NaN noise.
        """
        work_dtype = dtype if torch.finfo(dtype).bits >= 32 else torch.float32
        U = torch.rand(shape, device=device, dtype=work_dtype)
        return (-torch.log(-torch.log(U + eps) + eps)).to(dtype)

    def gumbel_sigmoid_sample(self):
        """
        Sample a relaxed binary gate from ``gate_logit``.

        The difference of two Gumbel samples is added to the logit and the sum is
        divided by the current gate temperature before the sigmoid.
        """
        dev = self.gate_logit.device
        dt = self.gate_logit.dtype
        g0 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        g1 = self.sample_gumbel(self.gate_logit.size(), dev, dt)
        y = torch.sigmoid((self.gate_logit + g1 - g0) / self.gate_temperature)
        return y

    def forward(self, source_kv: Tuple[Tensor, Tensor], target_kv: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Project source key-value tensors to target dimension and add to target tensors with learnable weights.

        Args:
            source_kv: Tuple of (key, value) tensors, each (B, H_s, N, D_s)
            target_kv: Tuple of (key, value) tensors, each (B, H_t, N, D_t)
        Returns:
            Tuple of (key, value) tensors, each (B, H_t, N, D_t), computed as
            ``(1 - w) * target + gate * w * projected_source``
        """
        source_key, source_value = source_kv
        target_key, target_value = target_kv

        # Layout is (B, H, N, D): H is num_heads, N is seq_len, D is head_dim
        batch_size, source_num_heads, seq_len, source_head_dim = source_key.shape
        _, target_num_heads, _, target_head_dim = target_key.shape

        # Merge heads per token: (B, H_s, N, D_s) -> (B, N, H_s * D_s)
        source_width = source_num_heads * source_head_dim
        source_key_flat = source_key.transpose(1, 2).reshape(batch_size, seq_len, source_width)
        source_value_flat = source_value.transpose(1, 2).reshape(batch_size, seq_len, source_width)

        # Project source tensors from (H_s * D_s) to (H_t * D_t)
        projected_key_flat = self.key_projection(source_key_flat)  # (B, N, H_t * D_t)
        projected_value_flat = self.value_projection(source_value_flat)  # (B, N, H_t * D_t)

        # Split the heads again: (B, N, H_t * D_t) -> (B, N, H_t, D_t) -> (B, H_t, N, D_t)
        target_split = (batch_size, seq_len, target_num_heads, target_head_dim)
        projected_key = projected_key_flat.view(target_split).transpose(1, 2)
        projected_value = projected_value_flat.view(target_split).transpose(1, 2)

        # Mixing weights from the target keys: the linear layers act on the last
        # dimension only, so there is one weight per (batch, head, token).
        weight_hidden = self.weight_hidden(target_key)  # (B, H_t, N, hidden_dim)
        key_weight = self.key_weight_head(weight_hidden)  # (B, H_t, N, 1)
        value_weight = self.value_weight_head(weight_hidden)  # (B, H_t, N, 1)

        # Scalar gate: stochastic relaxation while training, hard 0/1 decision in eval
        if self.training:
            gate = self.gumbel_sigmoid_sample()
        else:
            gate = (self.gate_logit > 0).to(self.gate_logit.dtype)

        # A small scalar_temperature sharpens the sigmoid towards 0 / 1
        normalized_key_weight = torch.sigmoid(key_weight / self.scalar_temperature)
        normalized_value_weight = torch.sigmoid(value_weight / self.scalar_temperature)

        # The multiplications already allocate new tensors, so the targets need no clone
        output_key = (1 - normalized_key_weight) * target_key + gate * normalized_key_weight * projected_key
        output_value = (1 - normalized_value_weight) * target_value + gate * normalized_value_weight * projected_value

        return (output_key, output_value)


def save_projector(obj: Projector, file_path: str) -> None:
    """
    Save a projector to ``file_path``.

    Thin wrapper that forwards both arguments unchanged to ``save_object`` from
    ``rosetta.utils.registry``; the on-disk format is whatever that helper writes.

    Args:
        obj: Projector instance to save
        file_path: Destination path handed to ``save_object``
    """
    save_object(obj, file_path)


def load_projector(file_path: str, override_args: Optional[dict] = None) -> Projector:
    """
    Load a projector from ``file_path``.

    Delegates to ``load_object`` from ``rosetta.utils.registry``, handing it
    ``get_projector_class`` as the class lookup together with ``override_args``.

    Args:
        file_path: Path handed to ``load_object``
        override_args: Optional dict forwarded to ``load_object``; presumably
            constructor arguments that replace the saved ones (confirm in the registry)
    Returns:
        The object returned by ``load_object``
    """
    projector = load_object(file_path, get_projector_class, override_args)
    return projector


# Import all projector implementations to ensure they are registered
from . import all_in_one_projector