import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
from dataclasses import dataclass


@dataclass
class KSKTConfig:
    """Configuration class for KSKT model.

    ``thinking_budgets`` and ``dsaa_layers`` default to ``None`` and are
    replaced in ``__post_init__`` with freshly built lists, so instances never
    share a mutable default.

    Raises:
        ValueError: if ``num_kv_heads`` is not positive or does not evenly
            divide ``num_attention_heads``. Grouped-query attention repeats
            each KV head ``num_attention_heads // num_kv_heads`` times, so a
            remainder would otherwise surface later as a shape error.
    """
    hidden_size: int = 2560
    num_attention_heads: int = 32
    num_kv_heads: int = 8
    head_dim: int = 128
    intermediate_size: int = 9728
    num_layers: int = 36
    max_position_embeddings: int = 262144
    rope_theta: float = 5000000.0
    num_experts: int = 4
    # Candidate step counts for the System-2 thinking chain.
    thinking_budgets: Optional[List[int]] = None
    # 1-based layer numbers that use the DSAA / reasoning / MoE stack.
    dsaa_layers: Optional[List[int]] = None
    fusion_eps: float = 1e-8

    def __post_init__(self):
        if self.thinking_budgets is None:
            self.thinking_budgets = [2, 4, 6, 8, 10]
        if self.dsaa_layers is None:
            self.dsaa_layers = [4, 8, 12, 16, 20, 24, 28, 32, 36]
        # Check positivity first so the modulo below cannot divide by zero.
        if self.num_kv_heads <= 0 or self.num_attention_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be a "
                f"positive multiple of num_kv_heads ({self.num_kv_heads})"
            )


class DualStreamAxialAttention(nn.Module):
    """
    Dual-Stream Axial Attention for factorizing self-understanding and other-understanding.

    Two grouped-query attention streams ("self" and "other") attend over the
    same input ``H``. Each stream has its own Q/K/V projections and a learned,
    content-dependent per-key bias. The two outputs are blended per token with
    ``alpha = sigmoid(fusion_w(H))`` and ``beta = 1 - alpha``. The blend is
    then projected back to ``hidden_size`` by ``o_proj``.
    """

    def __init__(self, config: KSKTConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.eps = config.fusion_eps  # NOTE(review): not used inside this class.

        # Self-stream projections (construction order fixes RNG/init order).
        self.W_q_self = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.W_k_self = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.W_v_self = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)

        # Other-stream projections.
        self.W_q_other = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.W_k_other = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.W_v_other = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)

        # Output projection
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # Per-token, per-head attention bias for each stream.
        self.bias_self_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
        self.bias_other_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)

        # Per-token scalar logit for the self/other fusion gate.
        self.fusion_w = nn.Linear(self.hidden_size, 1, bias=True)

        # Number of query heads sharing each KV head (GQA).
        self.num_key_value_groups = self.num_heads // self.num_kv_heads

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init (gain 0.02) for the six Q/K/V projections.

        ``o_proj``, the bias projections and ``fusion_w`` keep PyTorch's
        default ``nn.Linear`` initialization.
        """
        for module in [self.W_q_self, self.W_k_self, self.W_v_self,
                       self.W_q_other, self.W_k_other, self.W_v_other]:
            nn.init.xavier_uniform_(module.weight, gain=0.02)

    def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat KV heads for GQA.

        ``[batch, kv_heads, seq, head_dim]`` -> ``[batch, kv_heads * n_rep, seq, head_dim]``.
        """
        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim
        )
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

    def _compute_stream_attention(
        self,
        H: torch.Tensor,
        W_q: nn.Linear,
        W_k: nn.Linear,
        W_v: nn.Linear,
        bias_proj: nn.Linear,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute attention for a single stream (self or other).

        Args:
            H: ``[batch, seq, hidden_size]`` input.
            W_q, W_k, W_v: the stream's projections.
            bias_proj: maps ``H`` to one bias value per head and key position.
            position_embeddings: optional ``(cos, sin)`` RoPE tables, either
                ``[seq, head_dim]`` or per-sample ``[batch, seq, head_dim]``.
            attention_mask: optional additive mask, added directly to the
                ``[batch, heads, seq, seq]`` logits, so it must broadcast to
                that shape.

        Returns:
            ``[batch, seq, num_heads * head_dim]`` attention output.
        """
        batch_size, seq_len, _ = H.shape

        # Project Q, K, V -> [batch, heads, seq, head_dim]
        Q = W_q(H).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = W_k(H).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = W_v(H).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            Q = self._apply_rotary_pos_emb(Q, cos, sin)
            K = self._apply_rotary_pos_emb(K, cos, sin)

        # Repeat K, V for GQA
        K = self._repeat_kv(K, self.num_key_value_groups)
        V = self._repeat_kv(V, self.num_key_value_groups)

        # Compute attention scores: [batch, heads, seq_q, seq_k]
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # bias_proj(H) is [batch, seq, heads]. Reshape it to
        # [batch, heads, 1, seq_k] so each key position gets its own offset,
        # broadcast over queries. A bias that is constant along the key axis
        # would be cancelled exactly by the softmax and receive no gradient.
        bias = bias_proj(H).transpose(1, 2).unsqueeze(2)
        attn_weights = attn_weights + bias

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax over keys in float32 for stability, then back to V's dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(V.dtype)

        # Compute attention output
        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, -1)

        return attn_output

    def _apply_rotary_pos_emb(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor
    ) -> torch.Tensor:
        """Apply Rotary Position Embedding (rotate-half form).

        ``x`` is ``[batch, heads, seq, head_dim]``. ``cos``/``sin`` may be
        ``[seq, head_dim]``, per-sample ``[batch, seq, head_dim]``, or already
        ``[batch, 1, seq, head_dim]``.
        """
        # Per-sample 3-D tables need an explicit head axis. Plain right-aligned
        # broadcasting would otherwise line their batch axis up with x's head
        # axis.
        if cos.dim() == 3:
            cos = cos.unsqueeze(1)
        if sin.dim() == 3:
            sin = sin.unsqueeze(1)
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        rotated = torch.cat([-x2, x1], dim=-1)
        return x * cos + rotated * sin

    def forward(
        self,
        H: torch.Tensor,
        R_proc: torch.Tensor,
        U_proc: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass of Dual-Stream Axial Attention.

        Args:
            H: ``[batch, seq, hidden_size]`` (already normalized by the caller).
            R_proc, U_proc: accepted for interface compatibility; this
                implementation does not read them.
            position_embeddings: optional ``(cos, sin)`` RoPE tables.
            attention_mask: optional additive attention mask.

        Returns:
            ``(H_out, alpha, beta)``. ``H_out`` is ``[batch, seq, hidden_size]``.
            ``alpha`` and ``beta = 1 - alpha`` are the ``[batch, seq, 1]``
            fusion weights of the self and other streams.
        """
        # Compute self-stream attention
        H_self = self._compute_stream_attention(
            H, self.W_q_self, self.W_k_self, self.W_v_self,
            self.bias_self_proj, position_embeddings, attention_mask
        )

        # Compute other-stream attention
        H_other = self._compute_stream_attention(
            H, self.W_q_other, self.W_k_other, self.W_v_other,
            self.bias_other_proj, position_embeddings, attention_mask
        )

        # Per-token convex combination of the two streams.
        alpha = torch.sigmoid(self.fusion_w(H))
        beta = 1.0 - alpha

        H_out = self.o_proj(alpha * H_self + beta * H_other)

        return H_out, alpha, beta


class MutualUnderstandingPositionEncoding(nn.Module):
    """
    Mutual-Understanding Position Encoding augmenting RoPE with perspective-specific signals.

    Standard RoPE cos/sin tables are rescaled element-wise by a factor in
    ``(0.95, 1.05)``. The factor is computed from a sinusoidal position code
    together with the mean-pooled role context (``R_proc``) and user context
    (``U_proc``).
    """

    def __init__(self, config: KSKTConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.head_dim = config.head_dim

        # f_self: MLP over [position code (head_dim); pooled R_proc (hidden)]
        self.mlp_self = nn.Sequential(
            nn.Linear(self.hidden_size + self.head_dim, self.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(self.hidden_size // 2, self.hidden_size)
        )

        # f_other: MLP over [position code (head_dim); pooled U_proc (hidden)]
        self.mlp_other = nn.Sequential(
            nn.Linear(self.hidden_size + self.head_dim, self.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(self.hidden_size // 2, self.hidden_size)
        )

        # W_self and W_other projection matrices
        self.W_self = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.W_other = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # Base RoPE components
        self.rope_theta = config.rope_theta
        self._init_rope()

    def _init_rope(self):
        """Register the non-persistent ``inv_freq`` buffer (``head_dim // 2`` entries)."""
        inv_freq = 1.0 / (self.rope_theta ** (
            torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim
        ))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _compute_base_rope(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return float32 ``(cos, sin)`` tables of shape ``[seq_len, head_dim]``."""
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        return emb.cos(), emb.sin()

    def forward(
        self,
        seq_len: int,
        R_proc: torch.Tensor,
        U_proc: torch.Tensor,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute mutual-understanding position encoding.

        Args:
            seq_len: number of positions to encode.
            R_proc: role context ``[batch, n_r, hidden_size]``; mean-pooled over dim 1.
            U_proc: user context ``[batch, n_u, hidden_size]``; mean-pooled over dim 1.
            device: device for the position tensors (should match ``inv_freq``).

        Returns:
            ``(cos, sin)``, each ``[batch, seq_len, head_dim]``: the base RoPE
            tables multiplied by ``1 + 0.1 * (sigmoid(signal) - 0.5)``.
        """
        batch_size = R_proc.shape[0]

        # Compute base RoPE
        cos_base, sin_base = self._compute_base_rope(seq_len, device)

        # Sinusoidal position code [seq_len, head_dim]. It is built in float32,
        # so cast it to the context dtype. Otherwise torch.cat promotes
        # half-precision inputs to float32 and the fp16/bf16 MLPs reject them.
        positions = torch.arange(seq_len, device=device, dtype=torch.float32)
        pos_enc = self._position_to_embedding(positions).to(R_proc.dtype)

        # Mean-pooled context representations [batch, hidden]
        R_mean = R_proc.mean(dim=1)
        U_mean = U_proc.mean(dim=1)

        pos_enc_expanded = pos_enc.unsqueeze(0).expand(batch_size, -1, -1)
        R_expanded = R_mean.unsqueeze(1).expand(-1, seq_len, -1)
        U_expanded = U_mean.unsqueeze(1).expand(-1, seq_len, -1)

        self_input = torch.cat([pos_enc_expanded, R_expanded], dim=-1)
        other_input = torch.cat([pos_enc_expanded, U_expanded], dim=-1)

        f_self = self.mlp_self(self_input)
        f_other = self.mlp_other(other_input)

        # Apply W_self and W_other projections and sum the two perspectives.
        mutual_signal = self.W_self(f_self) + self.W_other(f_other)

        # Only the first head_dim features drive the per-dimension scale.
        scale = torch.sigmoid(mutual_signal[..., :self.head_dim])
        factor = 1 + 0.1 * (scale - 0.5)

        # Apply scaling to base RoPE
        cos = cos_base.unsqueeze(0) * factor
        sin = sin_base.unsqueeze(0) * factor

        return cos, sin

    def _position_to_embedding(self, positions: torch.Tensor) -> torch.Tensor:
        """Map positions ``[n]`` to ``[n, head_dim]`` as ``[sin(freqs), cos(freqs)]``."""
        freqs = torch.outer(positions, self.inv_freq)
        emb = torch.cat([freqs.sin(), freqs.cos()], dim=-1)
        return emb


class BipolarReasoningModule(nn.Module):
    """
    Bipolar Reasoning Module combining fast (System 1) and slow (System 2) pathways.

    System 1 is a pre-normalized SwiGLU FFN that runs on every sample. A
    pre-gate, fed with sequence-mean features, decides per sample whether to
    also run System 2. System 2 is an iterative cross-attention "thinking
    chain" over the concatenated role/user context. Its step count is picked
    from ``config.thinking_budgets``. For triggered samples a per-token
    post-gate ``g`` blends the two paths as ``g * H_fast + (1 - g) * H_slow``.
    Untriggered samples return ``H_fast`` unchanged.
    """
    
    def __init__(self, config: KSKTConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.thinking_budgets = config.thinking_budgets
        self.num_budgets = len(self.thinking_budgets)
        
        # RMSNorm for pre-normalization
        self.rmsnorm = RMSNorm(self.hidden_size)
        
        # System 1 (fast) pathway.
        self.ffn_fast = SwiGLUFFN(self.hidden_size, self.intermediate_size)
        
        # Thinking chain components for System 2
        # nn.MultiheadAttention requires hidden_size to be divisible by the
        # head count (num_attention_heads // 4).
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=self.hidden_size,
            num_heads=config.num_attention_heads // 4, 
            batch_first=True
        )
        # Per-step update from [h; cross-attended context].
        self.thinking_mlp = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size * 4),
            nn.SiLU(),
            nn.Linear(self.hidden_size * 4, self.hidden_size)
        )
        self.thinking_norm = nn.LayerNorm(self.hidden_size)
        
        # Pre-gate: [mean(H_dsaa); mean(H_fast)] -> probability of running System 2.
        self.pre_gate = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, 1),
            nn.Sigmoid()
        )
    
        # Post-gate: per-token [H_dsaa; H_fast; H_slow] -> weight g on the fast path.
        self.post_gate = nn.Sequential(
            nn.Linear(self.hidden_size * 3, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, 1),
            nn.Sigmoid()
        )
        
        
        # Sequence-mean of H_dsaa -> logits over thinking_budgets entries.
        self.budget_classifier = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.ReLU(),
            nn.Linear(self.hidden_size // 2, self.num_budgets)
        )
        
        # Samples whose pre-gate probability exceeds this run System 2.
        self.pre_gate_threshold = 0.5
    
    def _thinking_chain(
        self,
        H_dsaa: torch.Tensor,
        R_proc: torch.Tensor,
        U_proc: torch.Tensor,
        T: int
    ) -> torch.Tensor:
        """
        Execute thinking chain with T reasoning steps.

        Each step cross-attends the current state ``h`` (queries) over the
        role+user context (keys/values, concatenated along dim 1). It then
        applies ``thinking_mlp`` to ``[h; c]`` and updates
        ``h = LayerNorm(h + t)``. The same weights are reused at every step.

        Args:
            H_dsaa: initial state, ``[batch, seq, hidden]`` (batch_first).
            R_proc, U_proc: context tensors whose last dim is ``hidden``.
            T: number of steps. ``T == 0`` returns ``H_dsaa`` unchanged.

        Returns:
            The final state, same shape as ``H_dsaa``.
        """
        # Concatenate role and user context for cross-attention
        context = torch.cat([R_proc, U_proc], dim=1) 
        
        h = H_dsaa  
        
        for step in range(T):
            
            # c: context summary for each position of h.
            c, _ = self.cross_attention(h, context, context)
            
            
            h_c = torch.cat([h, c], dim=-1)
            t = self.thinking_mlp(h_c)
            
            
            # Residual update followed by LayerNorm.
            h = self.thinking_norm(h + t)
        
        return h
    
    def forward(
        self,
        H_dsaa: torch.Tensor,
        R_proc: torch.Tensor,
        U_proc: torch.Tensor,
        return_metrics: bool = False
    ) -> Tuple[torch.Tensor, dict]:
        """
        Forward pass of Bipolar Reasoning Module.

        Args:
            H_dsaa: ``[batch, seq, hidden]`` output of the attention block.
            R_proc, U_proc: role / user context for the thinking chain.
            return_metrics: when True, fill the metrics dict with
                ``system2_trigger_rate``. When System 2 ran, it also gets
                ``thinking_budget`` and ``gate_mean``.

        Returns:
            ``(H_reason, metrics)``. ``H_reason`` has the shape of ``H_dsaa``.

        NOTE(review): ``p_sys2`` is only used through a hard ``>`` threshold,
        and ``budget_logits`` only through ``argmax``. As written, neither
        ``pre_gate`` nor ``budget_classifier`` receives gradient from this
        forward pass.
        """
        batch_size, seq_len, _ = H_dsaa.shape
        metrics = {}
        
        
        # System 1: pre-norm SwiGLU on every sample.
        H_normalized = self.rmsnorm(H_dsaa)
        H_fast = self.ffn_fast(H_normalized)
        
        
        # Pre-gate sees sequence-mean features of both input and fast output.
        pre_gate_input = torch.cat([
            H_dsaa.mean(dim=1),  
            H_fast.mean(dim=1)
        ], dim=-1)
        p_sys2 = self.pre_gate(pre_gate_input)  # [batch, 1]
        
        
        trigger_mask = (p_sys2 > self.pre_gate_threshold).squeeze(-1)  # [batch]
        
        if return_metrics:
            metrics['system2_trigger_rate'] = trigger_mask.float().mean().item()
        
        
        # Start from the fast path; triggered rows are overwritten below.
        H_reason = H_fast.clone()
        
        if trigger_mask.any():
            
            triggered_indices = torch.where(trigger_mask)[0]
            
            
            # Run System 2 only on the triggered subset of the batch.
            H_dsaa_triggered = H_dsaa[triggered_indices]
            H_fast_triggered = H_fast[triggered_indices]
            R_proc_triggered = R_proc[triggered_indices]
            U_proc_triggered = U_proc[triggered_indices]
            
            
            budget_logits = self.budget_classifier(H_dsaa_triggered.mean(dim=1))
            budget_idx = budget_logits.argmax(dim=-1)  # [num_triggered]
            
            # One shared step count for the whole triggered subset: the most
            # common per-sample choice (torch.mode).
            T = self.thinking_budgets[budget_idx.mode().values.item()]
            
            if return_metrics:
                metrics['thinking_budget'] = T
            
            
            H_slow_triggered = self._thinking_chain(
                H_dsaa_triggered, R_proc_triggered, U_proc_triggered, T
            )
            
            
            post_gate_input = torch.cat([
                H_dsaa_triggered, H_fast_triggered, H_slow_triggered
            ], dim=-1)
            g = self.post_gate(post_gate_input)  # [num_triggered, seq, 1]
            
            
            # g weights the fast path; (1 - g) weights the slow path.
            H_reason_triggered = g * H_fast_triggered + (1 - g) * H_slow_triggered
            
            
            # Scatter blended rows back; untriggered samples keep H_fast.
            H_reason[triggered_indices] = H_reason_triggered
            
            if return_metrics:
                metrics['gate_mean'] = g.mean().item()
        
        return H_reason, metrics

class SelfAwarenessMoE(nn.Module):
    """
    Self-Awareness Mixture of Experts with character-centric routing.

    One routing distribution per sample is derived from the role context
    ``R_proc``, as seen by the current hidden states. It weights a sum of the
    outputs of *all* experts: a dense/soft mixture in which every expert runs
    on every token. A squared-deviation load-balance loss, scaled by
    ``balance_weight``, is returned alongside the output.
    """

    def __init__(self, config: KSKTConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.num_experts = config.num_experts

        # Names used in metric keys. The list is truncated to num_experts and
        # padded with generic names past the four named experts, so the metrics
        # loop in forward() never indexes past the routing probabilities.
        named = ['Personality', 'Knowledge', 'Emotional', 'Capability']
        self.expert_names = named[:self.num_experts] + [
            f'Expert{j}' for j in range(len(named), self.num_experts)
        ]

        self.experts = nn.ModuleList([
            SwiGLUFFN(self.hidden_size, self.intermediate_size)
            for _ in range(self.num_experts)
        ])

        # Small-std re-init of every expert weight matrix.
        for expert in self.experts:
            for param in expert.parameters():
                if param.dim() > 1:
                    nn.init.normal_(param, std=0.01)

        # Projection for computing attention over R_proc
        self.W_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        # SwiGLUFFN(d, k) maps d -> k -> d. The routing query q_self therefore
        # has 2 * hidden_size features, and the router must match that width.
        # (A [num_experts, hidden_size] router cannot be multiplied with q_self.)
        self.routing_dim = self.hidden_size * 2
        self.routing_swiglu = SwiGLUFFN(self.routing_dim, self.hidden_size)

        # NOTE(review): std-1 init over routing_dim features can give large
        # logits and a near one-hot softmax at init; consider a smaller scale.
        self.router_weights = nn.Parameter(torch.randn(self.num_experts, self.routing_dim))
        self.router_biases = nn.Parameter(torch.zeros(self.num_experts))

        # Learned softmax temperature, clamped to [0.1, 2.0] when used.
        self.temperature = nn.Parameter(torch.ones(1))

        self.balance_weight = 0.01

    def _compute_routing_query(
        self,
        H_reason: torch.Tensor,
        R_proc: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute self-reflective routing query q_self.

        Each sequence position attends over the role tokens. The attended
        summaries are averaged over the sequence and concatenated with the
        plain role mean. The result is passed through ``routing_swiglu``.

        Args:
            H_reason: ``[batch, seq, hidden]``.
            R_proc: ``[batch, n_r, hidden]``.

        Returns:
            ``[batch, 2 * hidden]`` routing query.
        """
        hidden = H_reason.shape[-1]

        R_projected = self.W_proj(R_proc)                                      # [B, n_r, H]
        attn_scores = torch.matmul(R_projected, H_reason.transpose(-2, -1))   # [B, n_r, S]
        attn_scores = attn_scores / math.sqrt(hidden)
        # Normalize over the role tokens (dim=1), so every sequence position
        # gets a convex combination of R_proc rows. Normalizing over the
        # sequence axis instead makes the mean below collapse to
        # (n_r / seq) * R_proc.mean(1), independent of H_reason.
        M_role = F.softmax(attn_scores, dim=1)

        h_role = torch.matmul(M_role.transpose(-2, -1), R_proc)               # [B, S, H]
        h_role = h_role.mean(dim=1)                                             # [B, H]

        h_context = R_proc.mean(dim=1)                                          # [B, H]

        h_concat = torch.cat([h_role, h_context], dim=-1)                       # [B, 2H]
        q_self = self.routing_swiglu(h_concat.unsqueeze(1)).squeeze(1)         # [B, 2H]

        return q_self

    def _compute_routing_probs(self, q_self: torch.Tensor) -> torch.Tensor:
        """
        Compute expert routing probabilities ``[batch, num_experts]``.
        """
        logits = torch.matmul(q_self, self.router_weights.T) + self.router_biases

        tau = torch.clamp(self.temperature, min=0.1, max=2.0)
        logits = logits / tau

        return F.softmax(logits, dim=-1)

    def _compute_load_balance_loss(self, routing_probs: torch.Tensor) -> torch.Tensor:
        """
        Sum of squared deviations of the batch-mean routing probability from uniform.
        """
        # Mean routing probability per expert over the batch
        f = routing_probs.mean(dim=0)

        target = 1.0 / self.num_experts

        return ((f - target) ** 2).sum()

    def forward(
        self,
        H_reason: torch.Tensor,
        R_proc: torch.Tensor,
        return_metrics: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        """
        Forward pass of Self-Awareness MoE.

        Returns:
            ``(H_expert, weighted_balance_loss, metrics)``. ``H_expert`` has the
            shape of ``H_reason``. When ``return_metrics`` is True, ``metrics``
            holds ``expert_<name>_prob`` batch means.
        """
        batch_size, seq_len, hidden = H_reason.shape
        metrics = {}

        q_self = self._compute_routing_query(H_reason, R_proc)
        routing_probs = self._compute_routing_probs(q_self)

        if return_metrics:
            for i, name in enumerate(self.expert_names):
                metrics[f'expert_{name}_prob'] = routing_probs[:, i].mean().item()

        # Dense mixture: every expert processes every token.
        H_expert = torch.zeros_like(H_reason)
        for j in range(self.num_experts):
            expert_output = self.experts[j](H_reason)
            weight = routing_probs[:, j].view(batch_size, 1, 1)
            H_expert = H_expert + weight * expert_output

        balance_loss = self._compute_load_balance_loss(routing_probs)

        return H_expert, balance_loss * self.balance_weight, metrics


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Computes ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``.

    The statistics are accumulated in at least float32. For fp16/bf16 inputs
    this avoids ``x.pow(2)`` overflowing (e.g. any |x| > ~256 in fp16 gives
    ``inf``). The normalized activations are cast back to the input dtype
    before the weight is applied. float32 and float64 inputs are computed in
    their own precision.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Upcast half types to float32; leave float32/float64 unchanged.
        x_acc = x.to(torch.promote_types(input_dtype, torch.float32))
        variance = x_acc.pow(2).mean(-1, keepdim=True)
        x_normed = (x_acc * torch.rsqrt(variance + self.eps)).to(input_dtype)
        return self.weight * x_normed


class SwiGLUFFN(nn.Module):
    """
    Gated feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    Three bias-free linear layers map ``hidden_size -> intermediate_size ->
    hidden_size``, so the output width always equals the input width.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.hidden_size, self.intermediate_size = hidden_size, intermediate_size

        def _bias_free(n_in: int, n_out: int) -> nn.Linear:
            return nn.Linear(n_in, n_out, bias=False)

        # Built in gate -> up -> down order so random initialization draws
        # happen in the same sequence.
        self.gate_proj = _bias_free(hidden_size, intermediate_size)
        self.up_proj = _bias_free(hidden_size, intermediate_size)
        self.down_proj = _bias_free(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))

class KSKTLayer(nn.Module):
    """
    Complete KSKT Layer combining all components.

    Layers whose 1-based index (``layer_idx + 1``) appears in
    ``config.dsaa_layers`` run MUPE -> DSAA -> Bipolar Reasoning ->
    Self-Awareness MoE, with residual connections. All other layers run a
    placeholder path (see ``forward``).

    After each forward pass of a DSAA layer, ``self.aux_loss`` holds the
    weighted MoE load-balance loss as a tensor that still carries gradient,
    so training code can add it to its objective. The ``metrics`` entry of the
    same name is only a detached Python float. ``aux_loss`` stays ``None`` for
    non-DSAA layers.
    """

    def __init__(self, config: KSKTConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.use_dsaa = (layer_idx + 1) in config.dsaa_layers

        # Pre-normalization
        self.input_layernorm = RMSNorm(config.hidden_size)
        self.post_attention_layernorm = RMSNorm(config.hidden_size)

        if self.use_dsaa:
            # DSAA for designated layers
            self.attention = DualStreamAxialAttention(config)
            self.mupe = MutualUnderstandingPositionEncoding(config)
            self.bipolar = BipolarReasoningModule(config)
            self.samoe = SelfAwarenessMoE(config)
        else:
            # Standard attention placeholder
            self.attention = None
            self.ffn = SwiGLUFFN(config.hidden_size, config.intermediate_size)

        # Differentiable auxiliary (load-balance) loss from the latest forward.
        self.aux_loss: Optional[torch.Tensor] = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        R_proc: torch.Tensor,
        U_proc: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_metrics: bool = False
    ) -> Tuple[torch.Tensor, dict]:
        """
        Forward pass of KSKT layer.

        Args:
            hidden_states: ``[batch, seq, hidden_size]``.
            R_proc, U_proc: role / user context forwarded to the sub-modules.
            attention_mask: optional additive mask passed to DSAA unchanged.
            return_metrics: collect per-component diagnostics.

        Returns:
            ``(output, metrics)``. ``output`` has the shape of ``hidden_states``.
        """
        metrics = {}

        if self.use_dsaa:
            # Step 1: MUPE position tables from the (un-normalized) context
            seq_len = hidden_states.shape[1]
            position_embeddings = self.mupe(
                seq_len, R_proc, U_proc, hidden_states.device
            )

            # Step 2: Pre-norm
            normed_hidden = self.input_layernorm(hidden_states)

            # Step 3: Dual-Stream Axial Attention + residual
            H_dsaa, alpha, beta = self.attention(
                normed_hidden, R_proc, U_proc,
                position_embeddings, attention_mask
            )
            H_dsaa = hidden_states + H_dsaa

            if return_metrics:
                metrics['alpha_mean'] = alpha.mean().item()
                metrics['beta_mean'] = beta.mean().item()

            # Step 4: Bipolar Reasoning + residual
            H_reason, brm_metrics = self.bipolar(H_dsaa, R_proc, U_proc, return_metrics)
            metrics.update(brm_metrics)
            H_reason = H_dsaa + H_reason

            # Step 5: Self-Awareness MoE on the post-norm stream + residual
            H_normed = self.post_attention_layernorm(H_reason)
            H_expert, balance_loss, moe_metrics = self.samoe(H_normed, R_proc, return_metrics)
            metrics.update(moe_metrics)
            # Keep the graph-connected tensor for training. The metrics entry
            # below is a detached float kept for backward compatibility.
            self.aux_loss = balance_loss if isinstance(balance_loss, torch.Tensor) else None
            metrics['balance_loss'] = balance_loss.item() if isinstance(balance_loss, torch.Tensor) else balance_loss

            output = H_reason + H_expert

        else:
            # Placeholder path: no attention is computed. The residual adds
            # the RMS-normalized input itself, followed by a pre-norm FFN.
            # NOTE(review): this is not a real attention layer; confirm intent.
            normed_hidden = self.input_layernorm(hidden_states)
            attn_output = normed_hidden
            hidden_states = hidden_states + attn_output

            normed_hidden = self.post_attention_layernorm(hidden_states)
            ffn_output = self.ffn(normed_hidden)
            output = hidden_states + ffn_output

        return output, metrics

class KSKTModel(nn.Module):
    """
    Complete KnowSelf-KnowOther Transformer Model.

    Token embedding -> ``config.num_layers`` KSKT layers -> RMSNorm -> LM head.

    The vocabulary size comes from the ``vocab_size`` argument if given, then
    from a ``vocab_size`` attribute on ``config`` if present, and otherwise
    defaults to 151936.
    """

    def __init__(self, config: KSKTConfig, vocab_size: Optional[int] = None):
        super().__init__()
        self.config = config
        if vocab_size is None:
            vocab_size = getattr(config, "vocab_size", 151936)
        self.vocab_size = vocab_size

        # Embedding layer
        self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size)

        # KSKT Layers
        self.layers = nn.ModuleList([
            KSKTLayer(config, i) for i in range(config.num_layers)
        ])

        # Final normalization
        self.norm = RMSNorm(config.hidden_size)

        # LM head (not weight-tied to the embedding)
        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        R_proc: torch.Tensor,
        U_proc: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_metrics: bool = False
    ) -> Tuple[torch.Tensor, dict]:
        """
        Forward pass of KSKT model.

        Args:
            input_ids: ``[batch, seq]`` token ids.
            R_proc, U_proc: role / user context passed to every layer.
            attention_mask: passed unchanged to every layer. No causal mask is
                built here. NOTE(review): for autoregressive use the caller
                presumably must supply an additive causal mask; confirm.
            return_metrics: when True, per-layer metrics are returned under
                keys ``layer_{i}_{name}``.

        Returns:
            ``(logits, metrics)``, with logits of shape ``[batch, seq, vocab_size]``.
        """
        hidden_states = self.embed_tokens(input_ids)

        all_metrics = {}

        for i, layer in enumerate(self.layers):
            hidden_states, metrics = layer(
                hidden_states, R_proc, U_proc,
                attention_mask, return_metrics
            )
            if return_metrics and metrics:
                for k, v in metrics.items():
                    all_metrics[f'layer_{i}_{k}'] = v

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        return logits, all_metrics


if __name__ == "__main__":
    # Demo: build the full-size model and run one inference forward pass.
    config = KSKTConfig(
        hidden_size=2560,
        num_attention_heads=32,
        num_kv_heads=8,
        head_dim=128,
        intermediate_size=9728,
        num_layers=36,
    )

    # Create model
    model = KSKTModel(config)
    print(f"KSKT Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    # torch.no_grad() does not disable dropout (e.g. the Dropout(0.1) in the
    # Bipolar post-gate), so switch to eval mode for a deterministic pass.
    model.eval()

    batch_size = 2
    seq_len = 512
    n_r = 64  # Role context length
    n_u = 128  # User intent length

    input_ids = torch.randint(0, 151936, (batch_size, seq_len))
    R_proc = torch.randn(batch_size, n_r, config.hidden_size)
    U_proc = torch.randn(batch_size, n_u, config.hidden_size)

    # Forward pass
    with torch.no_grad():
        logits, metrics = model(input_ids, R_proc, U_proc, return_metrics=True)

    print(f"Output shape: {logits.shape}")
    print(f"Sample metrics: {list(metrics.keys())[:5]}")